1
0
forked from extern/smegmesh

Compare commits

...

37 Commits

Author SHA1 Message Date
1f0914e2df bugfix-node-not-leaving
- Add lock when perform synchronisation on concurrent access
2024-01-04 00:23:20 +00:00
27e00196cd main
- Not waiting in the waitgroup
2024-01-02 20:31:24 +00:00
dea6f1a22d main
- error in code invalid check for nil
2024-01-02 20:19:34 +00:00
913de57568 main
- Fixed bug
2024-01-02 20:11:11 +00:00
ce829114b1 bugfix
- on synchornisation node is not leaving mesh
2024-01-02 19:41:20 +00:00
cd844ff46e - Fixing DNS error 2024-01-02 00:15:23 +00:00
d0b1913796 74-perform-dad
- Fixing nil pointer dereference
2024-01-02 00:13:04 +00:00
90cfe820d2 - Fixing errors with stale paths 2024-01-02 00:09:31 +00:00
8a49809855 74-perform-dad
- Adding go.sum to fix errors
2024-01-01 23:59:04 +00:00
dbc18bddc6 74-perform-dad
- Performing DAD to check if IPv6 address present before adding
  outselves to mesh
- Changing name from wgmesh to smegmesh
2024-01-01 23:55:50 +00:00
36e82dba47 72-pull-rate-in-configuration
- Refactored pull rate into the configuration
- code freeze so no more code changes
2023-12-31 14:25:06 +00:00
3cc87bc252 72-pull-rate-in-configuration
- Updated examples
2023-12-31 12:47:45 +00:00
a9ed7c0a20 72-pull-rate-in-configuration
- Removing libp2p reference
2023-12-31 12:47:45 +00:00
fd29af73e3 72-pull-rate-in-configuration
- Added pull rate to configuration (finally) so this can
be modified by an administrator.
2023-12-31 12:47:45 +00:00
9e1058e0f2 72-pull-rate-in-configuration
- Added the pull rate to the configuration file
2023-12-31 12:47:45 +00:00
c29eb197f3 Merge pull request #71 from tim-beatham/66-ipv6-address-not-conforming-to-spec
66 ipv6 address not conforming to spec
2023-12-30 22:26:53 +00:00
1a9d9d61ad 66-ipv6-address-not-conforming-to-spec
- Missing commit
2023-12-30 22:26:08 +00:00
6954608c32 66-ipv6-address-not-confirming-to-spec
- UUID is not random just a name generator needs changing to shortuuid
- When in multiple meshes there is no wait group
2023-12-30 22:24:43 +00:00
2e6aed6f93 main
- Fixing issue with nil pointer de-reference due to bad design of mesh
  manager.
- Going forward all references to GetSelf should be depracated. It
  introduces a race condition when leaving a mesh network
2023-12-30 00:44:57 +00:00
b0893a0b8e Merge pull request #69 from tim-beatham/60-unit-test-crdt-data-store
60-unit-test-crdt-data-store
2023-12-29 22:06:20 +00:00
e7d6055fa3 60-unit-test-crdt-data-store
Provided unit tests for datastore.go
And fixed unit tets failing by different way of providing CA
2023-12-29 22:05:05 +00:00
e0f3f116b9 main
- Stale serverConfig entry causing certificate authorities
to not become authorised
2023-12-29 19:54:08 +00:00
352648b7cb main
- Fixed problem where connection not removed on error
2023-12-29 11:12:40 +00:00
2d5df25b1d main
- If deadline exceeded error remove connection from
connection manager
2023-12-29 01:29:11 +00:00
cabe173831 main
Adding retry parameter
2023-12-29 01:10:26 +00:00
d2c8a52ec6 main
- Adding retry policy for mobility
2023-12-29 00:58:43 +00:00
bf53108384 main
- Bugfix, fix consistent hash problem where
if failure happens then causes panic
2023-12-28 23:24:38 +00:00
77aac5534b main
- Bugfix in client where "-" was attempted to be parsed as a UDP addr
2023-12-28 17:46:04 +00:00
58439fcd56 main
- Bugfix when keepalivewg is not set causes segmentation fault
- give keepalive a default value of 0 if not set
2023-12-28 17:32:54 +00:00
311a15363a Merge pull request #67 from tim-beatham/66-improve-graph-dot-tool
66 improve graph dot tool
2023-12-25 01:26:15 +00:00
255d3c8b39 66-improve-graph-dot-tool
- Showing services a node provides
- Showing all meshes not just one
- Showing the default route
2023-12-25 01:25:20 +00:00
41899c5831 66-improve-graph-dot-tool
Improving the graph dot tool so that it shows all
meshes
2023-12-25 01:10:11 +00:00
fe4ca66ff6 Merge pull request #65 from tim-beatham/64-2p-set-unit-test
64 2p set unit test
2023-12-22 23:58:59 +00:00
0b91ba744a 61-improve-unit-test-coverage
- Provided unit tests for g_map and 2p_map
2023-12-22 23:57:10 +00:00
67483c2a90 64-unit-test-two-phase-set
Provide unit tests for two phase set to make it more
transparent what exactly they are doing.
2023-12-22 23:57:10 +00:00
af26e81bd3 Merge pull request #63 from tim-beatham/61-improve-unit-testing-coverage
61-improve-unit-testing-coverage
2023-12-22 21:52:46 +00:00
186acbe915 Merge pull request #62 from tim-beatham/61-improve-unit-testing-coverage
61-improve-unit-testing-coverage
2023-12-22 21:49:06 +00:00
65 changed files with 2062 additions and 806 deletions

BIN
api Executable file

Binary file not shown.

View File

@ -3,7 +3,7 @@ package main
import ( import (
"log" "log"
"github.com/tim-beatham/wgmesh/pkg/api" "github.com/tim-beatham/smegmesh/pkg/api"
) )
func main() { func main() {

View File

@ -3,7 +3,7 @@ package main
import ( import (
"log" "log"
smegdns "github.com/tim-beatham/wgmesh/pkg/dns" smegdns "github.com/tim-beatham/smegmesh/pkg/dns"
) )
func main() { func main() {

View File

@ -6,8 +6,10 @@ import (
"os" "os"
"github.com/akamensky/argparse" "github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log" graph "github.com/tim-beatham/smegmesh/pkg/dot"
"github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
const SockAddr = "/tmp/wgmesh_ipc.sock" const SockAddr = "/tmp/wgmesh_ipc.sock"
@ -20,25 +22,22 @@ type CreateMeshParams struct {
AdvertiseDefault bool AdvertiseDefault bool
} }
func createMesh(params *CreateMeshParams) string { func createMesh(client *ipc.SmegmeshIpc, args *ipc.NewMeshArgs) {
var reply string var reply string
newMeshParams := ipc.NewMeshArgs{ err := client.CreateMesh(args, &reply)
WgArgs: params.WgArgs,
}
err := params.Client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply)
if err != nil { if err != nil {
return err.Error() fmt.Println(err.Error())
return
} }
return reply fmt.Println(reply)
} }
func listMeshes(client *ipcRpc.Client) { func listMeshes(client *ipc.SmegmeshIpc) {
reply := new(ipc.ListMeshReply) reply := new(ipc.ListMeshReply)
err := client.Call("IpcHandler.ListMeshes", "", &reply) err := client.ListMeshes(reply)
if err != nil { if err != nil {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())
@ -50,38 +49,22 @@ func listMeshes(client *ipcRpc.Client) {
} }
} }
type JoinMeshParams struct { func joinMesh(client *ipc.SmegmeshIpc, args ipc.JoinMeshArgs) {
Client *ipcRpc.Client
MeshId string
IpAddress string
Endpoint string
WgArgs ipc.WireGuardArgs
AdvertiseRoutes bool
AdvertiseDefault bool
}
func joinMesh(params *JoinMeshParams) string {
var reply string var reply string
args := ipc.JoinMeshArgs{ err := client.JoinMesh(args, &reply)
MeshId: params.MeshId,
IpAdress: params.IpAddress,
WgArgs: params.WgArgs,
}
err := params.Client.Call("IpcHandler.JoinMesh", &args, &reply)
if err != nil { if err != nil {
return err.Error() fmt.Println(err.Error())
} }
return reply fmt.Println(reply)
} }
func leaveMesh(client *ipcRpc.Client, meshId string) { func leaveMesh(client *ipc.SmegmeshIpc, meshId string) {
var reply string var reply string
err := client.Call("IpcHandler.LeaveMesh", &meshId, &reply) err := client.LeaveMesh(meshId, &reply)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
@ -91,10 +74,51 @@ func leaveMesh(client *ipcRpc.Client, meshId string) {
fmt.Println(reply) fmt.Println(reply)
} }
func getGraph(client *ipcRpc.Client, meshId string) { func getGraph(client *ipc.SmegmeshIpc) {
listMeshesReply := new(ipc.ListMeshReply)
err := client.ListMeshes(listMeshesReply)
if err != nil {
fmt.Println(err.Error())
return
}
meshes := make(map[string][]ctrlserver.MeshNode)
for _, meshId := range listMeshesReply.Meshes {
var meshReply ipc.GetMeshReply
err := client.GetMesh(meshId, &meshReply)
if err != nil {
fmt.Println(err.Error())
return
}
meshes[meshId] = meshReply.Nodes
}
dotGenerator := graph.NewMeshGraphConverter(meshes)
dot, err := dotGenerator.Generate()
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(dot)
}
func queryMesh(client *ipc.SmegmeshIpc, meshId, query string) {
var reply string var reply string
err := client.Call("IpcHandler.GetDOT", &meshId, &reply) args := ipc.QueryMesh{
MeshId: meshId,
Query: query,
}
err := client.Query(args, &reply)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
@ -104,24 +128,13 @@ func getGraph(client *ipcRpc.Client, meshId string) {
fmt.Println(reply) fmt.Println(reply)
} }
func queryMesh(client *ipcRpc.Client, meshId, query string) { func putDescription(client *ipc.SmegmeshIpc, meshId, description string) {
var reply string var reply string
err := client.Call("IpcHandler.Query", &ipc.QueryMesh{MeshId: meshId, Query: query}, &reply) err := client.PutDescription(ipc.PutDescriptionArgs{
MeshId: meshId,
if err != nil { Description: description,
fmt.Println(err.Error()) }, &reply)
return
}
fmt.Println(reply)
}
// putDescription: puts updates the description about the node to the meshes
func putDescription(client *ipcRpc.Client, description string) {
var reply string
err := client.Call("IpcHandler.PutDescription", &description, &reply)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
@ -132,10 +145,13 @@ func putDescription(client *ipcRpc.Client, description string) {
} }
// putAlias: puts an alias for the node // putAlias: puts an alias for the node
func putAlias(client *ipcRpc.Client, alias string) { func putAlias(client *ipc.SmegmeshIpc, meshid, alias string) {
var reply string var reply string
err := client.Call("IpcHandler.PutAlias", &alias, &reply) err := client.PutAlias(ipc.PutAliasArgs{
MeshId: meshid,
Alias: alias,
}, &reply)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
@ -145,15 +161,14 @@ func putAlias(client *ipcRpc.Client, alias string) {
fmt.Println(reply) fmt.Println(reply)
} }
func setService(client *ipcRpc.Client, service, value string) { func setService(client *ipc.SmegmeshIpc, meshId, service, value string) {
var reply string var reply string
serviceArgs := &ipc.PutServiceArgs{ err := client.PutService(ipc.PutServiceArgs{
MeshId: meshId,
Service: service, Service: service,
Value: value, Value: value,
} }, &reply)
err := client.Call("IpcHandler.PutService", serviceArgs, &reply)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
@ -163,10 +178,13 @@ func setService(client *ipcRpc.Client, service, value string) {
fmt.Println(reply) fmt.Println(reply)
} }
func deleteService(client *ipcRpc.Client, service string) { func deleteService(client *ipc.SmegmeshIpc, meshId, service string) {
var reply string var reply string
err := client.Call("IpcHandler.PutService", &service, &reply) err := client.DeleteService(ipc.DeleteServiceArgs{
MeshId: meshId,
Service: service,
}, &reply)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
@ -177,8 +195,8 @@ func deleteService(client *ipcRpc.Client, service string) {
} }
func main() { func main() {
parser := argparse.NewParser("wg-mesh", parser := argparse.NewParser("smgctl",
"wg-mesh Manipulate WireGuard mesh networks") "smegctl Manipulate WireGuard mesh networks")
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh") newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to") listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
@ -201,7 +219,6 @@ func main() {
}) })
var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{ var newMeshRole *string = newMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Default: "peer",
Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" + Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" +
" in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" + " in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" +
" protocol", " protocol",
@ -234,7 +251,6 @@ func main() {
}) })
var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{ var joinMeshRole *string = joinMeshCmd.Selector("r", "role", []string{"peer", "client"}, &argparse.Options{
Default: "peer",
Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" + Help: "Role in the mesh network. A value of peer means that the node is publicly routeable and thus considered" +
" in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" + " in the gossip protocol. Client means that the node is not publicly routeable and is not a candidate in the gossip" +
" protocol", " protocol",
@ -258,11 +274,6 @@ func main() {
Help: "Advertise ::/0 into the mesh network", Help: "Advertise ::/0 into the mesh network",
}) })
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{
Required: true,
Help: "MeshID of the graph to get",
})
var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{ var leaveMeshMeshId *string = leaveMeshCmd.String("m", "mesh", &argparse.Options{
Required: true, Required: true,
Help: "MeshID of the mesh to leave", Help: "MeshID of the mesh to leave",
@ -282,6 +293,16 @@ func main() {
Help: "Description of the node in the mesh", Help: "Description of the node in the mesh",
}) })
var descriptionMeshId *string = putDescriptionCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var aliasMeshId *string = putAliasCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{ var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{
Required: true, Required: true,
Help: "Alias of the node to set can be used in DNS to lookup an IP address", Help: "Alias of the node to set can be used in DNS to lookup an IP address",
@ -296,11 +317,21 @@ func main() {
Help: "Value of the service to advertise in the mesh network", Help: "Value of the service to advertise in the mesh network",
}) })
var serviceMeshId *string = setServiceCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{ var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{
Required: true, Required: true,
Help: "Key of the service to remove", Help: "Key of the service to remove",
}) })
var deleteServiceMeshid *string = deleteServiceCmd.String("m", "meshid", &argparse.Options{
Required: true,
Help: "MeshID of the mesh network to join",
})
err := parser.Parse(os.Args) err := parser.Parse(os.Args)
if err != nil { if err != nil {
@ -308,16 +339,13 @@ func main() {
return return
} }
client, err := ipcRpc.DialHTTP("unix", SockAddr) client, err := ipc.NewClientIpc()
if err != nil { if err != nil {
fmt.Println(err.Error()) panic(err)
return
} }
if newMeshCmd.Happened() { if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{ args := &ipc.NewMeshArgs{
Client: client,
Endpoint: *newMeshEndpoint,
WgArgs: ipc.WireGuardArgs{ WgArgs: ipc.WireGuardArgs{
Endpoint: *newMeshEndpoint, Endpoint: *newMeshEndpoint,
Role: *newMeshRole, Role: *newMeshRole,
@ -326,7 +354,9 @@ func main() {
AdvertiseDefaultRoute: *newMeshAdvertiseDefaults, AdvertiseDefaultRoute: *newMeshAdvertiseDefaults,
AdvertiseRoutes: *newMeshAdvertiseRoutes, AdvertiseRoutes: *newMeshAdvertiseRoutes,
}, },
})) }
createMesh(client, args)
} }
if listMeshCmd.Happened() { if listMeshCmd.Happened() {
@ -334,11 +364,9 @@ func main() {
} }
if joinMeshCmd.Happened() { if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{ args := ipc.JoinMeshArgs{
Client: client,
IpAddress: *joinMeshIpAddress, IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId, MeshId: *joinMeshId,
Endpoint: *joinMeshEndpoint,
WgArgs: ipc.WireGuardArgs{ WgArgs: ipc.WireGuardArgs{
Endpoint: *joinMeshEndpoint, Endpoint: *joinMeshEndpoint,
Role: *joinMeshRole, Role: *joinMeshRole,
@ -347,11 +375,12 @@ func main() {
AdvertiseDefaultRoute: *joinMeshAdvertiseDefaults, AdvertiseDefaultRoute: *joinMeshAdvertiseDefaults,
AdvertiseRoutes: *joinMeshAdvertiseRoutes, AdvertiseRoutes: *joinMeshAdvertiseRoutes,
}, },
})) }
joinMesh(client, args)
} }
if getGraphCmd.Happened() { if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId) getGraph(client)
} }
if leaveMeshCmd.Happened() { if leaveMeshCmd.Happened() {
@ -363,18 +392,18 @@ func main() {
} }
if putDescriptionCmd.Happened() { if putDescriptionCmd.Happened() {
putDescription(client, *description) putDescription(client, *descriptionMeshId, *description)
} }
if putAliasCmd.Happened() { if putAliasCmd.Happened() {
putAlias(client, *alias) putAlias(client, *aliasMeshId, *alias)
} }
if setServiceCmd.Happened() { if setServiceCmd.Happened() {
setService(client, *serviceKey, *serviceValue) setService(client, *serviceMeshId, *serviceKey, *serviceValue)
} }
if deleteServiceCmd.Happened() { if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey) deleteService(client, *deleteServiceMeshid, *deleteServiceKey)
} }
} }

View File

@ -6,14 +6,14 @@ import (
"os" "os"
"os/signal" "os/signal"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver" ctrlserver "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/robin" "github.com/tim-beatham/smegmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync" "github.com/tim-beatham/smegmesh/pkg/sync"
timer "github.com/tim-beatham/wgmesh/pkg/timers" timer "github.com/tim-beatham/smegmesh/pkg/timers"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
@ -59,6 +59,11 @@ func main() {
} }
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
if err != nil {
panic(err)
}
syncProvider.Server = ctrlServer syncProvider.Server = ctrlServer
syncRequester = sync.NewSyncRequester(ctrlServer) syncRequester = sync.NewSyncRequester(ctrlServer)
syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester) syncer = sync.NewSyncer(ctrlServer.MeshManager, conf, syncRequester)

View File

@ -10,5 +10,5 @@ syncRate: 1
interClusterChance: 0.15 interClusterChance: 0.15
branchRate: 3 branchRate: 3
infectionCount: 3 infectionCount: 3
keepAliveTime: 10 heartBeatTime: 10
pruneTime: 20 pruneTime: 20

View File

@ -10,5 +10,5 @@ syncRate: 1
interClusterChance: 0.15 interClusterChance: 0.15
branchRate: 3 branchRate: 3
infectionCount: 3 infectionCount: 3
keepAliveTime: 10 heartBeatTime: 10
pruneTime: 20 pruneTime: 20

View File

@ -10,5 +10,5 @@ syncRate: 1
interClusterChance: 0.15 interClusterChance: 0.15
branchRate: 3 branchRate: 3
infectionCount: 3 infectionCount: 3
keepAliveTime: 10 heartBeatTime: 10
pruneTime: 20 pruneTime: 20

4
go.mod
View File

@ -1,4 +1,4 @@
module github.com/tim-beatham/wgmesh module github.com/tim-beatham/smegmesh
go 1.21.3 go 1.21.3
@ -11,11 +11,11 @@ require (
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0 github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5 github.com/jsimonetti/rtnetlink v1.3.5
github.com/lithammer/shortuuid v3.0.0+incompatible
github.com/miekg/dns v1.1.57 github.com/miekg/dns v1.1.57
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0 golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
gonum.org/v1/gonum v0.14.0
google.golang.org/grpc v1.58.1 google.golang.org/grpc v1.58.1
google.golang.org/protobuf v1.31.0 google.golang.org/protobuf v1.31.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1

6
go.sum
View File

@ -27,8 +27,6 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE= github.com/go-playground/validator/v10 v10.16.0 h1:x+plE831WK4vaKHO/jpgUGsvLKIqRRkz6M78GuJAfGE=
github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-playground/validator/v10 v10.16.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
@ -57,6 +55,8 @@ github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZX
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/lithammer/shortuuid v3.0.0+incompatible h1:NcD0xWW/MZYXEHa6ITy6kaXN5nwm/V115vj2YXfhS0w=
github.com/lithammer/shortuuid v3.0.0+incompatible/go.mod h1:FR74pbAuElzOUuenUHTK2Tciko1/vKuIKS9dSkDrA4w=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw=
@ -123,8 +123,6 @@ golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 h1:EY138uSo1JYlDq+
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
gonum.org/v1/gonum v0.14.0 h1:2NiG67LD1tEH0D7kM+ps2V+fXmsAnpUeec7n8tcr4S0=
gonum.org/v1/gonum v0.14.0/go.mod h1:AoWeoz0becf9QMWtE8iWXNXc27fK4fNeHNf/oMejGfU=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U= google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 h1:bVf09lpb+OJbByTj913DRJioFFAjf/ZGxEz7MajTp2U=
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM= google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98/go.mod h1:TUfxEVdsvPg18p6AslUXFoLdpED4oBnGwyqk3dV1XzM=
google.golang.org/grpc v1.58.1 h1:OL+Vz23DTtrrldqHK49FUOPHyY75rvFqJfXC84NYW58= google.golang.org/grpc v1.58.1 h1:OL+Vz23DTtrrldqHK49FUOPHyY75rvFqJfXC84NYW58=

View File

@ -4,17 +4,13 @@ import (
"fmt" "fmt"
"net/http" "net/http"
ipcRpc "net/rpc"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/what8words" "github.com/tim-beatham/smegmesh/pkg/what8words"
) )
const SockAddr = "/tmp/wgmesh_ipc.sock"
type ApiServer interface { type ApiServer interface {
GetMeshes(c *gin.Context) GetMeshes(c *gin.Context)
Run(addr string) error Run(addr string) error
@ -22,7 +18,7 @@ type ApiServer interface {
type SmegServer struct { type SmegServer struct {
router *gin.Engine router *gin.Engine
client *ipcRpc.Client client *ipc.SmegmeshIpc
words *what8words.What8Words words *what8words.What8Words
} }
@ -106,7 +102,7 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
var reply string var reply string
err := s.client.Call("IpcHandler.CreateMesh", &ipcRequest, &reply) err := s.client.CreateMesh(&ipcRequest, &reply)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{ c.JSON(http.StatusBadRequest, &gin.H{
@ -132,8 +128,8 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
} }
ipcRequest := ipc.JoinMeshArgs{ ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId, MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap, IpAddress: joinMesh.Bootstrap,
WgArgs: ipc.WireGuardArgs{ WgArgs: ipc.WireGuardArgs{
WgPort: joinMesh.WgPort, WgPort: joinMesh.WgPort,
}, },
@ -141,7 +137,7 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
var reply string var reply string
err := s.client.Call("IpcHandler.JoinMesh", &ipcRequest, &reply) err := s.client.JoinMesh(ipcRequest, &reply)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{ c.JSON(http.StatusBadRequest, &gin.H{
@ -164,7 +160,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) {
getMeshReply := new(ipc.GetMeshReply) getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &meshid, &getMeshReply) err := s.client.GetMesh(meshid, getMeshReply)
if err != nil { if err != nil {
c.JSON(http.StatusNotFound, c.JSON(http.StatusNotFound,
@ -182,7 +178,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) {
func (s *SmegServer) GetMeshes(c *gin.Context) { func (s *SmegServer) GetMeshes(c *gin.Context) {
listMeshesReply := new(ipc.ListMeshReply) listMeshesReply := new(ipc.ListMeshReply)
err := s.client.Call("IpcHandler.ListMeshes", "", &listMeshesReply) err := s.client.ListMeshes(listMeshesReply)
if err != nil { if err != nil {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())
@ -195,7 +191,7 @@ func (s *SmegServer) GetMeshes(c *gin.Context) {
for _, mesh := range listMeshesReply.Meshes { for _, mesh := range listMeshesReply.Meshes {
getMeshReply := new(ipc.GetMeshReply) getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &mesh, &getMeshReply) err := s.client.GetMesh(mesh, getMeshReply)
if err != nil { if err != nil {
logging.Log.WriteErrorf(err.Error()) logging.Log.WriteErrorf(err.Error())
@ -215,7 +211,7 @@ func (s *SmegServer) Run(addr string) error {
} }
func NewSmegServer(conf ApiServerConf) (ApiServer, error) { func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr) client, err := ipc.NewClientIpc()
if err != nil { if err != nil {
return nil, err return nil, err
@ -239,9 +235,19 @@ func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
words: words, words: words,
} }
router.GET("/meshes", smegServer.GetMeshes) v1 := router.Group("/api/v1")
router.GET("/mesh/:meshid", smegServer.GetMesh) {
router.POST("/mesh/create", smegServer.CreateMesh) meshes := v1.Group("/meshes")
router.POST("/mesh/join", smegServer.JoinMesh) {
meshes.GET("/", smegServer.GetMeshes)
}
mesh := v1.Group("/mesh")
{
mesh.GET("/:meshid", smegServer.GetMesh)
mesh.POST("/create", smegServer.CreateMesh)
mesh.POST("/join", smegServer.JoinMesh)
}
}
return smegServer, nil return smegServer, nil
} }

View File

@ -9,10 +9,10 @@ import (
"time" "time"
"github.com/automerge/automerge-go" "github.com/automerge/automerge-go"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )

View File

@ -2,7 +2,7 @@ package automerge
import ( import (
"github.com/automerge/automerge-go" "github.com/automerge/automerge-go"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
type AutomergeSync struct { type AutomergeSync struct {

View File

@ -6,9 +6,9 @@ import (
"testing" "testing"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )

View File

@ -3,9 +3,9 @@ package automerge
import ( import (
"fmt" "fmt"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
type CrdtProviderFactory struct{} type CrdtProviderFactory struct{}

View File

@ -47,7 +47,6 @@ type WgConfiguration struct {
// If the user is globaly accessible they specify themselves as a client. // If the user is globaly accessible they specify themselves as a client.
Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"` Role *NodeType `yaml:"role" validate:"required,eq=client|eq=peer"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers. // KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"` KeepAliveWg *int `yaml:"keepAliveWg" validate:"omitempty,gte=0"`
// PreUp are WireGuard commands to run before adding the WG interface // PreUp are WireGuard commands to run before adding the WG interface
PreUp []string `yaml:"preUp"` PreUp []string `yaml:"preUp"`
@ -77,11 +76,13 @@ type DaemonConfiguration struct {
Profile bool `yaml:"profile"` Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types // StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"` StubWg bool `yaml:"stubWg"`
// SyncRate specifies how long the minimum time should be between synchronisation // SyncTime specifies how long the minimum time should be between synchronisation
SyncRate int `yaml:"syncRate" validate:"required,gte=1"` SyncTime int `yaml:"syncTime" validate:"required,gte=1"`
// KeepAliveTime: number of seconds before the leader of the mesh sends an update to // PullTime specifies the interval between checking for configuration changes
PullTime int `yaml:"pullTime" validate:"gte=0"`
// HeartBeat: number of seconds before the leader of the mesh sends an update to
// send to every member in the mesh // send to every member in the mesh
KeepAliveTime int `yaml:"keepAliveTime" validate:"required,gte=1"` HeartBeat int `yaml:"heartBeatTime" validate:"required,gte=1"`
// ClusterSize specifies how many neighbours you should synchronise with per round // ClusterSize specifies how many neighbours you should synchronise with per round
ClusterSize int `yaml:"clusterSize" validate:"gte=1"` ClusterSize int `yaml:"clusterSize" validate:"gte=1"`
// InterClusterChance specifies the probabilityof inter-cluster communication in a sync round // InterClusterChance specifies the probabilityof inter-cluster communication in a sync round
@ -126,26 +127,6 @@ func ValidateDaemonConfiguration(c *DaemonConfiguration) error {
return err return err
} }
// ParseMeshConfiguration: parses the mesh network configuration. Parses parameters such as
// keepalive time, role and so forth.
func ParseMeshConfiguration(filePath string) (*WgConfiguration, error) {
var conf WgConfiguration
yamlBytes, err := os.ReadFile(filePath)
if err != nil {
return nil, err
}
err = yaml.Unmarshal(yamlBytes, &conf)
if err != nil {
return nil, err
}
return &conf, ValidateMeshConfiguration(&conf)
}
// ParseDaemonConfiguration parses the mesh configuration and validates the configuration // ParseDaemonConfiguration parses the mesh configuration and validates the configuration
func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) { func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) {
var conf DaemonConfiguration var conf DaemonConfiguration
@ -162,6 +143,11 @@ func ParseDaemonConfiguration(filePath string) (*DaemonConfiguration, error) {
return nil, err return nil, err
} }
if conf.BaseConfiguration.KeepAliveWg == nil {
var keepAlive int = 0
conf.BaseConfiguration.KeepAliveWg = &keepAlive
}
return &conf, ValidateDaemonConfiguration(&conf) return &conf, ValidateDaemonConfiguration(&conf)
} }

View File

@ -21,11 +21,12 @@ func getExampleConfiguration() *DaemonConfiguration {
Timeout: 5, Timeout: 5,
Profile: false, Profile: false,
StubWg: false, StubWg: false,
SyncRate: 2, SyncTime: 2,
KeepAliveTime: 2, HeartBeat: 2,
ClusterSize: 64, ClusterSize: 64,
InterClusterChance: 0.15, InterClusterChance: 0.15,
BranchRate: 3, BranchRate: 3,
PullTime: 0,
InfectionCount: 2, InfectionCount: 2,
BaseConfiguration: WgConfiguration{ BaseConfiguration: WgConfiguration{
IPDiscovery: &discovery, IPDiscovery: &discovery,
@ -162,9 +163,9 @@ func TestBranchRateZero(t *testing.T) {
} }
} }
func TestSyncRateZero(t *testing.T) { func TestsyncTimeZero(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.SyncRate = 0 conf.SyncTime = 0
err := ValidateDaemonConfiguration(conf) err := ValidateDaemonConfiguration(conf)
@ -175,7 +176,7 @@ func TestSyncRateZero(t *testing.T) {
func TestKeepAliveTimeZero(t *testing.T) { func TestKeepAliveTimeZero(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
conf.KeepAliveTime = 0 conf.HeartBeat = 0
err := ValidateDaemonConfiguration(conf) err := ValidateDaemonConfiguration(conf)
if err == nil { if err == nil {
@ -215,6 +216,17 @@ func TestInfectionCountOne(t *testing.T) {
} }
} }
func TestPullTimeNegative(t *testing.T) {
conf := getExampleConfiguration()
conf.PullTime = -1
err := ValidateDaemonConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestValidConfiguration(t *testing.T) { func TestValidConfiguration(t *testing.T) {
conf := getExampleConfiguration() conf := getExampleConfiguration()
err := ValidateDaemonConfiguration(conf) err := ValidateDaemonConfiguration(conf)

View File

@ -23,9 +23,10 @@ func binarySearch(global []string, selfId string, groupSize int) (int, int) {
lower := 0 lower := 0
higher := len(global) - 1 higher := len(global) - 1
mid := (lower + higher) / 2
for (higher+1)-lower > groupSize { for (higher+1)-lower > groupSize {
mid := (lower + higher) / 2
if global[mid] < selfId { if global[mid] < selfId {
lower = mid + 1 lower = mid + 1
} else if global[mid] > selfId { } else if global[mid] > selfId {
@ -33,8 +34,6 @@ func binarySearch(global []string, selfId string, groupSize int) (int, int) {
} else { } else {
break break
} }
mid = (lower + higher) / 2
} }
return lower, int(math.Min(float64(lower+groupSize), float64(len(global)))) return lower, int(math.Min(float64(lower+groupSize), float64(len(global))))

View File

@ -6,7 +6,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )

View File

@ -7,7 +7,7 @@ import (
"os" "os"
"sync" "sync"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
// ConnectionManager defines an interface for maintaining peer connections // ConnectionManager defines an interface for maintaining peer connections
@ -19,9 +19,11 @@ type ConnectionManager interface {
// If the endpoint does not exist then add the connection. Returns an error // If the endpoint does not exist then add the connection. Returns an error
// if something went wrong // if something went wrong
GetConnection(endPoint string) (PeerConnection, error) GetConnection(endPoint string) (PeerConnection, error)
// HasConnections returns true if a client has already registered at the givne // HasConnections returns true if a peer has already registered at the given
// endpoint or false otherwise. // endpoint or false otherwise.
HasConnection(endPoint string) bool HasConnection(endPoint string) bool
// Removes a connection if it exists
RemoveConnection(endPoint string) error
// Goes through all the connections and closes eachone // Goes through all the connections and closes eachone
Close() error Close() error
} }
@ -32,7 +34,6 @@ type ConnectionManagerImpl struct {
// clientConnections maps an endpoint to a connection // clientConnections maps an endpoint to a connection
conLoc sync.RWMutex conLoc sync.RWMutex
clientConnections map[string]PeerConnection clientConnections map[string]PeerConnection
serverConfig *tls.Config
clientConfig *tls.Config clientConfig *tls.Config
connFactory PeerConnectionFactory connFactory PeerConnectionFactory
} }
@ -61,37 +62,25 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
return nil, err return nil, err
} }
serverAuth := tls.RequireAndVerifyClientCert
if params.SkipCertVerification {
serverAuth = tls.RequireAnyClientCert
}
certPool := x509.NewCertPool() certPool := x509.NewCertPool()
if !params.SkipCertVerification { if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified")
if params.CaCert == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.CaCert)
if err != nil {
return nil, err
}
certPool.AppendCertsFromPEM(caCert)
} }
serverConfig := &tls.Config{ caCert, err := os.ReadFile(params.CaCert)
ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert}, if err != nil {
return nil, err
}
if ok := certPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("could not parse PEM")
} }
clientConfig := &tls.Config{ clientConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
InsecureSkipVerify: params.SkipCertVerification, InsecureSkipVerify: params.SkipCertVerification,
Certificates: []tls.Certificate{cert},
RootCAs: certPool, RootCAs: certPool,
} }
@ -99,7 +88,6 @@ func NewConnectionManager(params *NewConnectionManagerParams) (ConnectionManager
connMgr := ConnectionManagerImpl{ connMgr := ConnectionManagerImpl{
sync.RWMutex{}, sync.RWMutex{},
connections, connections,
serverConfig,
clientConfig, clientConfig,
params.ConnFactory, params.ConnFactory,
} }
@ -150,6 +138,15 @@ func (m *ConnectionManagerImpl) HasConnection(endPoint string) bool {
return exists return exists
} }
// RemoveConnection removes the given connection if it exists
func (m *ConnectionManagerImpl) RemoveConnection(endPoint string) error {
m.conLoc.Lock()
err := m.clientConnections[endPoint].Close()
delete(m.clientConnections, endPoint)
m.conLoc.Unlock()
return err
}
func (m *ConnectionManagerImpl) Close() error { func (m *ConnectionManagerImpl) Close() error {
for _, conn := range m.clientConnections { for _, conn := range m.clientConnections {
if err := conn.Close(); err != nil { if err := conn.Close(); err != nil {

View File

@ -53,13 +53,13 @@ func TestNewConnectionManagerCACertDoesNotExistAndVerify(t *testing.T) {
func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) { func TestNewConnectionManagerCACertDoesNotExistAndNotVerify(t *testing.T) {
params := getConnectionManagerParams() params := getConnectionManagerParams()
params.CaCert = "" params.CaCert = "./cert/sdjsdjsdjk.pem"
params.SkipCertVerification = true params.SkipCertVerification = true
_, err := NewConnectionManager(params) _, err := NewConnectionManager(params)
if err != nil { if err == nil {
t.Fatal(`an error should not be thrown`) t.Fatalf(`an error should be thrown`)
} }
} }

View File

@ -2,22 +2,23 @@ package conn
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509"
"errors"
"fmt" "fmt"
"net" "net"
"os"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
) )
// ConnectionServer manages gRPC server peer connections // ConnectionServer manages gRPC server peer connections
type ConnectionServer struct { type ConnectionServer struct {
// tlsConfiguration of the server
serverConfig *tls.Config
// server an instance of the grpc server // server an instance of the grpc server
server *grpc.Server // the authentication service to authenticate nodes server *grpc.Server
// the ctrl service to manage node // the ctrl service to manage node
ctrlProvider rpc.MeshCtrlServerServer ctrlProvider rpc.MeshCtrlServerServer
// the sync service to synchronise nodes // the sync service to synchronise nodes
@ -48,9 +49,26 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
serverAuth = tls.RequireAnyClientCert serverAuth = tls.RequireAnyClientCert
} }
certPool := x509.NewCertPool()
if params.Conf.CaCertificatePath == "" {
return nil, errors.New("CA Cert is not specified")
}
caCert, err := os.ReadFile(params.Conf.CaCertificatePath)
if err != nil {
return nil, err
}
if ok := certPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("could not parse PEM")
}
serverConfig := &tls.Config{ serverConfig := &tls.Config{
ClientAuth: serverAuth, ClientAuth: serverAuth,
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
ClientCAs: certPool,
} }
server := grpc.NewServer( server := grpc.NewServer(
@ -61,7 +79,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
syncProvider := params.SyncProvider syncProvider := params.SyncProvider
connServer := ConnectionServer{ connServer := ConnectionServer{
serverConfig: serverConfig,
server: server, server: server,
ctrlProvider: ctrlProvider, ctrlProvider: ctrlProvider,
syncProvider: syncProvider, syncProvider: syncProvider,
@ -74,7 +91,6 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
// Listen for incoming requests. Returns an error if something went wrong. // Listen for incoming requests. Returns an error if something went wrong.
func (s *ConnectionServer) Listen() error { func (s *ConnectionServer) Listen() error {
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider) rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
rpc.RegisterSyncServiceServer(s.server, s.syncProvider) rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort)) lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.Conf.GrpcPort))

View File

@ -16,6 +16,11 @@ func (s *ConnectionManagerStub) AddConnection(endPoint string) (PeerConnection,
return mock, nil return mock, nil
} }
func (s *ConnectionManagerStub) RemoveConnection(endPoint string) error {
delete(s.Endpoints, endPoint)
return nil
}
func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) { func (s *ConnectionManagerStub) GetConnection(endPoint string) (PeerConnection, error) {
endpoint, ok := s.Endpoints[endPoint] endpoint, ok := s.Endpoints[endPoint]

View File

@ -9,10 +9,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -158,8 +158,8 @@ type TwoPhaseStoreMeshManager struct {
IfName string IfName string
Client *wgctrl.Client Client *wgctrl.Client
LastClock uint64 LastClock uint64
conf *conf.WgConfiguration Conf *conf.WgConfiguration
daemonConf *conf.DaemonConfiguration DaemonConf *conf.DaemonConfiguration
store *TwoPhaseMap[string, MeshNode] store *TwoPhaseMap[string, MeshNode]
} }
@ -204,7 +204,6 @@ func (m *TwoPhaseStoreMeshManager) Save() []byte {
var buf bytes.Buffer var buf bytes.Buffer
enc := gob.NewEncoder(&buf) enc := gob.NewEncoder(&buf)
err := enc.Encode(*snapshot) err := enc.Encode(*snapshot)
if err != nil { if err != nil {
@ -265,7 +264,7 @@ func (m *TwoPhaseStoreMeshManager) UpdateTimeStamp(nodeId string) error {
peerToUpdate := peers[0] peerToUpdate := peers[0]
if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.daemonConf.KeepAliveTime) { if uint64(time.Now().Unix())-m.store.Clock.GetTimestamp(peerToUpdate) > 3*uint64(m.DaemonConf.HeartBeat) {
m.store.Mark(peerToUpdate) m.store.Mark(peerToUpdate)
if len(peers) < 2 { if len(peers) < 2 {
@ -411,6 +410,11 @@ func (m *TwoPhaseStoreMeshManager) RemoveService(nodeId string, key string) erro
} }
node := m.store.Get(nodeId) node := m.store.Get(nodeId)
if _, ok := node.Services[key]; !ok {
return fmt.Errorf("datastore: node does not contain service %s", key)
}
delete(node.Services, key) delete(node.Services, key)
m.store.Put(nodeId, node) m.store.Put(nodeId, node)
return nil return nil
@ -510,5 +514,5 @@ func (m *TwoPhaseStoreMeshManager) RemoveNode(nodeId string) error {
// GetConfiguration implements mesh.MeshProvider. // GetConfiguration implements mesh.MeshProvider.
func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration { func (m *TwoPhaseStoreMeshManager) GetConfiguration() *conf.WgConfiguration {
return m.conf return m.Conf
} }

442
pkg/crdt/datastore_test.go Normal file
View File

@ -0,0 +1,442 @@
package crdt
import (
"net"
"slices"
"testing"
"time"
"github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type TestParams struct {
manager mesh.MeshProvider
publicKey *wgtypes.Key
}
func setUpTests() *TestParams {
advertiseRoutes := false
advertiseDefaultRoute := false
role := conf.PEER_ROLE
discovery := conf.DNS_IP_DISCOVERY
factory := &TwoPhaseMapFactory{
Config: &conf.DaemonConfiguration{
CertificatePath: "/somecertificatepath",
PrivateKeyPath: "/someprivatekeypath",
CaCertificatePath: "/somecacertificatepath",
SkipCertVerification: true,
GrpcPort: 0,
Timeout: 20,
Profile: false,
SyncTime: 2,
HeartBeat: 10,
ClusterSize: 32,
InterClusterChance: 0.15,
BranchRate: 3,
InfectionCount: 3,
BaseConfiguration: conf.WgConfiguration{
IPDiscovery: &discovery,
AdvertiseRoutes: &advertiseRoutes,
AdvertiseDefaultRoute: &advertiseDefaultRoute,
Role: &role,
},
},
}
key, _ := wgtypes.GeneratePrivateKey()
mesh, _ := factory.CreateMesh(&mesh.MeshProviderFactoryParams{
DevName: "bob",
MeshId: "meshid123",
Client: nil,
Conf: &factory.Config.BaseConfiguration,
DaemonConf: factory.Config,
NodeID: "bob",
})
publicKey := key.PublicKey()
return &TestParams{
manager: mesh,
publicKey: &publicKey,
}
}
func getOurNode(testParams *TestParams) *MeshNode {
return &MeshNode{
HostEndpoint: "public-endpoint:8080",
WgEndpoint: "public-endpoint:21906",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d218/128",
PublicKey: testParams.publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func getRandomNode() *MeshNode {
key, _ := wgtypes.GeneratePrivateKey()
publicKey := key.PublicKey()
return &MeshNode{
HostEndpoint: "public-endpoint:8081",
WgEndpoint: "public-endpoint:21907",
WgHost: "3e9a:1fb3:5e50:8173:9690:f917:b1ab:d234/128",
PublicKey: publicKey.String(),
Timestamp: time.Now().Unix(),
Description: "A node that we are adding",
Type: "peer",
}
}
func TestAddNodeAddsTheNodesToTheStore(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if !testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`node %s should have been added to the mesh network`, testParams.publicKey.String())
}
}
func TestAddNodeNodeAlreadyExistsReplacesTheNode(t *testing.T) {
TestAddNodeAddsTheNodesToTheStore(t)
TestAddNodeAddsTheNodesToTheStore(t)
}
func TestSaveThenLoad(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
bytes := testParams.manager.Save()
if err := testParams.manager.Load(bytes); err != nil {
t.Fatalf(`error caused by loading datastore: %s`, err.Error())
}
}
func TestHasChangesReturnsTrueWhenThereAreChangesInTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
if !testParams.manager.HasChanges() {
t.Fatalf(`mesh has change but HasChanges returned false`)
}
testParams.manager.SaveChanges()
}
func TestHasChangesWhenThereAreNoChangesInTheMeshReturnsFalse(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.AddNode(getRandomNode())
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
testParams.manager.SetDescription(testParams.publicKey.String(), "Bob marley")
testParams.manager.SaveChanges()
if testParams.manager.HasChanges() {
t.Fatalf(`mesh has no changes but HasChanges was true`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsTheLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
testParams.manager.UpdateTimeStamp(testParams.publicKey.String())
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() >= after.GetTimeStamp() {
t.Fatalf(`before should not be after after`)
}
}
func TestUpdateTimeStampUpdatesTheTimeStampOfTheGivenNodeIfItIsNotLeader(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
newNode := getRandomNode()
newNode.PublicKey = "aaaaaaaaaa"
testParams.manager.AddNode(newNode)
before, _ := testParams.manager.GetNode(testParams.publicKey.String())
time.Sleep(1 * time.Second)
after, _ := testParams.manager.GetNode(testParams.publicKey.String())
if before.GetTimeStamp() != after.GetTimeStamp() {
t.Fatalf(`before and after should be the same`)
}
}
func TestAddRoutesAddsARouteToTheGivenMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
testParams.manager.AddRoutes(testParams.publicKey.String(), &mesh.RouteStub{
Destination: destination,
HopCount: 0,
Path: make([]string, 0),
})
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if !containsDestination {
t.Fatalf(`route has not been added to the node`)
}
}
func TestRemoveRoutesWithdrawsRoutesFromTheMesh(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
_, destination, _ := net.ParseCIDR("0353:1da7:7f33:acc0:7a3f:6e55:912b:bc1f/64")
route := &mesh.RouteStub{
Destination: destination,
HopCount: 0,
Path: make([]string, 0),
}
testParams.manager.AddRoutes(testParams.publicKey.String(), route)
testParams.manager.RemoveRoutes(testParams.publicKey.String(), route)
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
containsDestination := lib.Contains(node.GetRoutes(), func(r mesh.Route) bool {
return r.GetDestination().Contains(destination.IP)
})
if containsDestination {
t.Fatalf(`route has not been removed from the node`)
}
}
func TestGetNodeGetsTheNodeWhenItExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node == nil {
t.Fatalf(`node not found returned nil`)
}
}
func TestGetNodeReturnsNilWhenItDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
if node != nil {
t.Fatalf(`node found but should be nil`)
}
}
func TestNodeExistsReturnsFalseWhenNotExists(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
testParams.manager.RemoveNode(testParams.publicKey.String())
if testParams.manager.NodeExists(testParams.publicKey.String()) {
t.Fatalf(`nodeexists should be false`)
}
}
func TestSetDescriptionReturnsErrorWhenNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetDescription("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetDescriptionSetsTheDescription(t *testing.T) {
testParams := setUpTests()
descriptionToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetDescription(testParams.publicKey.String(), descriptionToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
description := node.GetDescription()
if description != descriptionToSet {
t.Fatalf(`description was %s should be %s`, description, descriptionToSet)
}
}
func TestAliasNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.SetAlias("djdjdj", "djdsjkd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestSetAliasSetsAlias(t *testing.T) {
testParams := setUpTests()
aliasToSet := "djdsjkd"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.SetAlias(testParams.publicKey.String(), aliasToSet)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
alias := node.GetAlias()
if alias != aliasToSet {
t.Fatalf(`description was %s should be %s`, alias, aliasToSet)
}
}
func TestAddServiceNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.AddService("djdjdj", "djdsjkd", "sddsds")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestAddServiceNodeExists(t *testing.T) {
testParams := setUpTests()
service := "djdsjkd"
serviceValue := "dsdsds"
testParams.manager.AddNode(getOurNode(testParams))
err := testParams.manager.AddService(testParams.publicKey.String(), service, serviceValue)
if err != nil {
t.Fatalf(`error %s thrown`, err.Error())
}
node, _ := testParams.manager.GetNode(testParams.publicKey.String())
services := node.GetServices()
if value, ok := services[service]; !ok || value != serviceValue {
t.Fatalf(`service not added to the data store`)
}
}
func TestRemoveServiceDoesNotExists(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveService("djdjdj", "dsdssd")
if err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestRemoveServiceServiceDoesNotExist(t *testing.T) {
testParams := setUpTests()
testParams.manager.AddNode(getOurNode(testParams))
if err := testParams.manager.RemoveService(testParams.publicKey.String(), "dhsdh"); err == nil {
t.Fatalf(`error should be thrown`)
}
}
func TestGetPeersReturnsAllPeersInTheMesh(t *testing.T) {
testParams := setUpTests()
peer1 := getRandomNode()
peer2 := getRandomNode()
client := getRandomNode()
client.Type = "client"
testParams.manager.AddNode(peer1)
testParams.manager.AddNode(peer2)
testParams.manager.AddNode(client)
peers := testParams.manager.GetPeers()
slices.Sort(peers)
if len(peers) != 2 {
t.Fatalf(`there should be two peers in the mesh`)
}
peer1Pub, _ := peer1.GetPublicKey()
if !slices.Contains(peers, peer1Pub.String()) {
t.Fatalf(`peer1 not in the list`)
}
peer2Pub, _ := peer2.GetPublicKey()
if !slices.Contains(peers, peer2Pub.String()) {
t.Fatalf(`peer2 not in the list`)
}
}
func TestRemoveNodeReturnsErrorIfNodeDoesNotExist(t *testing.T) {
testParams := setUpTests()
err := testParams.manager.RemoveNode("dsjdssjk")
if err == nil {
t.Fatalf(`error should have returned`)
}
}

View File

@ -4,9 +4,9 @@ import (
"fmt" "fmt"
"hash/fnv" "hash/fnv"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
type TwoPhaseMapFactory struct { type TwoPhaseMapFactory struct {
@ -18,13 +18,13 @@ func (f *TwoPhaseMapFactory) CreateMesh(params *mesh.MeshProviderFactoryParams)
MeshId: params.MeshId, MeshId: params.MeshId,
IfName: params.DevName, IfName: params.DevName,
Client: params.Client, Client: params.Client,
conf: params.Conf, Conf: params.Conf,
daemonConf: params.DaemonConf, DaemonConf: params.DaemonConf,
store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 { store: NewTwoPhaseMap[string, MeshNode](params.NodeID, func(s string) uint64 {
h := fnv.New64a() h := fnv.New64a()
h.Write([]byte(s)) h.Write([]byte(s))
return h.Sum64() return h.Sum64()
}, uint64(3*f.Config.KeepAliveTime)), }, uint64(3*f.Config.HeartBeat)),
}, nil }, nil
} }

View File

@ -1,4 +1,4 @@
// crdt is a golang implementation of a crdt // crdt provides go implementations for crdts
package crdt package crdt
import ( import (
@ -65,10 +65,19 @@ func (g *GMap[K, D]) get(key uint64) Bucket[D] {
} }
func (g *GMap[K, D]) Get(key K) D { func (g *GMap[K, D]) Get(key K) D {
if !g.Contains(key) {
var def D
return def
}
return g.get(g.clock.hashFunc(key)).Contents return g.get(g.clock.hashFunc(key)).Contents
} }
func (g *GMap[K, D]) Mark(key K) { func (g *GMap[K, D]) Mark(key K) {
if !g.Contains(key) {
return
}
g.lock.Lock() g.lock.Lock()
bucket := g.contents[g.clock.hashFunc(key)] bucket := g.contents[g.clock.hashFunc(key)]
bucket.Gravestone = true bucket.Gravestone = true
@ -89,7 +98,6 @@ func (g *GMap[K, D]) IsMarked(key K) bool {
} }
g.lock.RUnlock() g.lock.RUnlock()
return marked return marked
} }

224
pkg/crdt/g_map_test.go Normal file
View File

@ -0,0 +1,224 @@
// crdt_test unit tests the crdt implementations
package crdt
import (
"hash/fnv"
"slices"
"testing"
"time"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
func NewGmap() *GMap[string, bool] {
vectorClock := NewVectorClock("a", func(key string) uint64 {
hash := fnv.New64a()
hash.Write([]byte(key))
return hash.Sum64()
}, 1) // 1 second stale time
gMap := NewGMap[string, bool](vectorClock)
return gMap
}
func TestGMapPutInsertsItems(t *testing.T) {
gMap := NewGmap()
gMap.Put("bruh1234", true)
if !gMap.Contains("bruh1234") {
t.Fatalf(`value not added to map`)
}
}
func TestGMapPutReplacesItems(t *testing.T) {
gMap := NewGmap()
gMap.Put("bruh1234", true)
gMap.Put("bruh1234", false)
value := gMap.Get("bruh1234")
if value {
t.Fatalf(`value should ahve been replaced to false`)
}
}
func TestContainsValueNotPresent(t *testing.T) {
gMap := NewGmap()
if gMap.Contains("sdhjsdhsdj") {
t.Fatalf(`value should not be present in the map`)
}
}
func TestContainsValuePresent(t *testing.T) {
gMap := NewGmap()
key := "hehehehe"
gMap.Put(key, false)
if !gMap.Contains(key) {
t.Fatalf(`%s should not be present in the map`, key)
}
}
func TestGMapGetNotPresentReturnsError(t *testing.T) {
gMap := NewGmap()
value := gMap.Get("bruh123")
if value != false {
t.Fatalf(`value should be default type false`)
}
}
func TestGMapGetReturnsValue(t *testing.T) {
gMap := NewGmap()
gMap.Put("bobdylan", true)
value := gMap.Get("bobdylan")
if !value {
t.Fatalf("value should be true but was false")
}
}
func TestMarkMarksTheValue(t *testing.T) {
gMap := NewGmap()
gMap.Put("hello123", true)
gMap.Mark("hello123")
if !gMap.IsMarked("hello123") {
t.Fatal(`hello123 should be marked`)
}
}
func TestMarkValueNotPresent(t *testing.T) {
gMap := NewGmap()
gMap.Mark("ok123456")
}
func TestKeysMapEmpty(t *testing.T) {
gMap := NewGmap()
keys := gMap.Keys()
if len(keys) != 0 {
t.Fatal(`list of keys was not empty but should be empty`)
}
}
func TestKeysMapReturnsKeysInMap(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
keys := gMap.Keys()
if len(keys) != 3 {
t.Fatal(`key length should be 3`)
}
}
func TestSaveMapEmptyReturnsEmptyMap(t *testing.T) {
gMap := NewGmap()
saveMap := gMap.Save()
if len(saveMap) != 0 {
t.Fatal(`saves should be empty`)
}
}
func TestSaveMapReturnsMapOfBuckets(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
saveMap := gMap.Save()
if len(saveMap) != 3 {
t.Fatalf(`save length should be 3`)
}
}
func TestSaveWithKeysNoKeysReturnsEmptyBucket(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
saveMap := gMap.SaveWithKeys([]uint64{})
if len(saveMap) != 0 {
t.Fatalf(`save map should be empty`)
}
}
func TestSaveWithKeysReturnsIntersection(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
clock := lib.MapKeys(gMap.GetClock())
clock = clock[:len(clock)-1]
values := gMap.SaveWithKeys(clock)
if len(values) != len(clock) {
t.Fatalf(`intersection not returned`)
}
}
func TestGetClockMapEmptyReturnsEmptyClock(t *testing.T) {
gMap := NewGmap()
clocks := gMap.GetClock()
if len(clocks) != 0 {
t.Fatalf(`vector clock is not empty`)
}
}
func TestGetClockReturnsAllCLocks(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
gMap.Put("b", false)
gMap.Put("c", false)
clocks := lib.MapValues(gMap.GetClock())
slices.Sort(clocks)
if !slices.Equal([]uint64{0, 1, 2}, clocks) {
t.Fatalf(`clocks are invalid`)
}
}
func TestGetHashChangesHashOnValueAdded(t *testing.T) {
gMap := NewGmap()
gMap.Put("a", false)
prevHash := gMap.GetHash()
gMap.Put("b", true)
if prevHash == gMap.GetHash() {
t.Fatalf(`hash should be different`)
}
}
func TestPruneGarbageCollectsValuesThatHaveNotBeenUpdated(t *testing.T) {
gMap := NewGmap()
gMap.clock.Put("c", 12)
gMap.Put("c", false)
gMap.Put("a", false)
time.Sleep(4 * time.Second)
gMap.Put("a", true)
gMap.Prune()
if gMap.Contains("c") {
t.Fatalf(`a should have been pruned`)
}
}

View File

@ -3,7 +3,7 @@ package crdt
import ( import (
"cmp" "cmp"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
) )
type TwoPhaseMap[K cmp.Ordered, D any] struct { type TwoPhaseMap[K cmp.Ordered, D any] struct {

View File

@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"encoding/gob" "encoding/gob"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
type SyncState int type SyncState int

View File

@ -0,0 +1,214 @@
package crdt
import (
"hash/fnv"
"slices"
"testing"
)
func NewMap(processId string) *TwoPhaseMap[string, string] {
theMap := NewTwoPhaseMap[string, string](processId, func(key string) uint64 {
hash := fnv.New64a()
hash.Write([]byte(key))
return hash.Sum64()
}, 1)
return theMap
}
func TestTwoPhaseMapEmpty(t *testing.T) {
theMap := NewMap("a")
if theMap.Contains("a") {
t.Fatalf(`a should not be present in the map`)
}
}
func TestTwoPhaseMapValuePresent(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
if !theMap.Contains("a") {
t.Fatalf(`should be present within the map`)
}
}
func TestTwoPhaseMapValueNotPresent(t *testing.T) {
theMap := NewMap("a")
theMap.Put("b", "")
if theMap.Contains("a") {
t.Fatalf(`a should not be present in the map`)
}
}
func TestTwoPhaseMapPutThenRemove(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Remove("a")
if theMap.Contains("a") {
t.Fatalf(`a should not be present within the map`)
}
}
func TestTwoPhaseMapPutThenRemoveThenPut(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Remove("a")
theMap.Put("a", "")
if !theMap.Contains("a") {
t.Fatalf(`a should be present within the map`)
}
}
func TestMarkMarksTheValueIn2PMap(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "")
theMap.Mark("a")
if !theMap.IsMarked("a") {
t.Fatalf(`a should be marked`)
}
}
func TestAsListReturnsItemsInList(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "bob")
theMap.Put("b", "dylan")
keys := theMap.AsList()
slices.Sort(keys)
if !slices.Equal([]string{"bob", "dylan"}, keys) {
t.Fatalf(`values should be bob, dylan`)
}
}
func TestSnapShotRemoveMapEmpty(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "bob")
theMap.Put("b", "dylan")
snapshot := theMap.Snapshot()
if len(snapshot.Add) != 2 {
t.Fatalf(`add values length should be 2`)
}
if len(snapshot.Remove) != 0 {
t.Fatalf(`remove map length should be 0`)
}
}
func TestSnapshotMapEmpty(t *testing.T) {
theMap := NewMap("a")
snapshot := theMap.Snapshot()
if len(snapshot.Add) != 0 || len(snapshot.Remove) != 0 {
t.Fatalf(`snapshot length should be 0`)
}
}
func TestSnapShotFromStateReturnsIntersection(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "heyy")
map2 := NewMap("b")
map2.Put("b", "hmmm")
message := map2.GenerateMessage()
snapShot := map1.SnapShotFromState(message)
if len(snapShot.Add) != 1 {
t.Fatalf(`add length should be 1`)
}
if len(snapShot.Remove) != 0 {
t.Fatalf(`remove length should be 0`)
}
}
func TestGetHashDifferentOnChange(t *testing.T) {
theMap := NewMap("a")
prevHash := theMap.GetHash()
theMap.Put("b", "hmmhmhmh")
if prevHash == theMap.GetHash() {
t.Fatalf(`hashes should not be the same`)
}
}
func TestGenerateMessageReturnsClocks(t *testing.T) {
theMap := NewMap("a")
theMap.Put("a", "hmm")
theMap.Put("b", "hmm")
theMap.Remove("a")
message := theMap.GenerateMessage()
if len(message.AddContents) != 2 {
t.Fatalf(`two items added add should be 2`)
}
if len(message.RemoveContents) != 1 {
t.Fatalf(`a was removed remove map should be length 1`)
}
}
func TestDifferenceReturnsDifferenceOfMaps(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "ssms")
map1.Put("b", "sdmdsmd")
map2 := NewMap("b")
map2.Put("d", "eek")
map2.Put("c", "meh")
message1 := map1.GenerateMessage()
message2 := map2.GenerateMessage()
difference := message1.Difference(0, message2)
if len(difference.AddContents) != 2 {
t.Fatalf(`d and c are not in map1 they should be in add contents`)
}
if len(difference.RemoveContents) != 0 {
t.Fatalf(`remove should be empty`)
}
}
func TestMergeMergesValuesThatAreGreaterThanCurrentClock(t *testing.T) {
map1 := NewMap("a")
map1.Put("a", "ssms")
map1.Put("b", "sdmdsmd")
map2 := NewMap("b")
map2.Put("d", "eek")
map2.Put("c", "meh")
message1 := map1.GenerateMessage()
message2 := map2.GenerateMessage()
difference := message1.Difference(0, message2)
state := map2.SnapShotFromState(difference)
map1.Merge(*state)
if !map1.Contains("d") {
t.Fatalf(`d should be in the map`)
}
if !map2.Contains("c") {
t.Fatalf(`c should be in the map`)
}
}

View File

@ -5,7 +5,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
) )
type VectorBucket struct { type VectorBucket struct {

View File

@ -1,16 +1,16 @@
package ctrlserver package ctrlserver
import ( import (
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/crdt" "github.com/tim-beatham/smegmesh/pkg/crdt"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
"github.com/tim-beatham/wgmesh/pkg/wg" "github.com/tim-beatham/smegmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
@ -34,7 +34,7 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
nodeFactory := &crdt.MeshNodeFactory{ nodeFactory := &crdt.MeshNodeFactory{
Config: *params.Conf, Config: *params.Conf,
} }
idGenerator := &lib.IDNameGenerator{} idGenerator := &lib.ShortIDGenerator{}
ipAllocator := &ip.ULABuilder{} ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client) interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
@ -89,7 +89,6 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
return ctrlServer, nil return ctrlServer, nil
} }
func (s *MeshCtrlServer) GetConfiguration() *conf.DaemonConfiguration { func (s *MeshCtrlServer) GetConfiguration() *conf.DaemonConfiguration {
return s.Conf return s.Conf
} }

View File

@ -4,11 +4,11 @@ import (
"net" "net"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )

View File

@ -1,10 +1,10 @@
package ctrlserver package ctrlserver
import ( import (
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/query"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )

View File

@ -4,21 +4,18 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net" "net"
"net/rpc"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/smegmesh/pkg/query"
) )
const SockAddr = "/tmp/wgmesh_ipc.sock"
const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.` const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.`
type DNSHandler struct { type DNSHandler struct {
client *rpc.Client client *ipc.SmegmeshIpc
server *dns.Server server *dns.Server
} }
@ -27,7 +24,7 @@ type DNSHandler struct {
func (d *DNSHandler) queryMesh(meshId, alias string) net.IP { func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
var reply string var reply string
err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{ err := d.client.Query(ipc.QueryMesh{
MeshId: meshId, MeshId: meshId,
Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias), Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias),
}, &reply) }, &reply)
@ -97,7 +94,7 @@ func (h *DNSHandler) Close() error {
} }
func NewDns(udpPort int) (*DNSHandler, error) { func NewDns(udpPort int) (*DNSHandler, error) {
client, err := rpc.DialHTTP("unix", SockAddr) client, err := ipc.NewClientIpc()
if err != nil { if err != nil {
return nil, err return nil, err

227
pkg/dot/dot.go Normal file
View File

@ -0,0 +1,227 @@
// Graph allows the definition of a DOT graph in golang
package graph
import (
"fmt"
"hash/fnv"
"strings"
"github.com/tim-beatham/smegmesh/pkg/lib"
)
type GraphType string
type Shape string
const (
GRAPH GraphType = "graph"
DIGRAPH GraphType = "digraph"
)
const (
CIRCLE Shape = "circle"
STAR Shape = "star"
HEXAGON Shape = "hexagon"
PARALLELOGRAM Shape = "parallelogram"
)
type Graph interface {
Dottable
GetType() GraphType
}
type Cluster struct {
Type GraphType
Name string
Label string
nodes map[string]*Node
edges map[string]Edge
}
type RootGraph struct {
Type GraphType
Label string
nodes map[string]*Node
clusters map[string]*Cluster
edges map[string]Edge
}
type Node struct {
Name string
Label string
Shape Shape
Size int
}
type Edge interface {
Dottable
}
type DirectedEdge struct {
Name string
Label string
From string
To string
}
type UndirectedEdge struct {
Name string
Label string
From string
To string
}
// Dottable means an implementer can convert the struct to DOT representation
type Dottable interface {
GetDOT() (string, error)
}
func NewGraph(label string, graphType GraphType) *RootGraph {
return &RootGraph{Type: graphType, Label: label, clusters: map[string]*Cluster{}, nodes: make(map[string]*Node), edges: make(map[string]Edge)}
}
// PutNode: puts a node in the graph
func (g *RootGraph) PutNode(name, label string, size int, shape Shape) error {
_, exists := g.nodes[name]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[name] = &Node{Name: name, Label: label, Size: size, Shape: shape}
return nil
}
func (g *RootGraph) PutCluster(graph *Cluster) {
g.clusters[graph.Label] = graph
}
func writeContituents[D Dottable](result *strings.Builder, elements ...D) error {
for _, node := range elements {
dot, err := node.GetDOT()
if err != nil {
return err
}
_, err = result.WriteString(dot)
if err != nil {
return err
}
}
return nil
}
func (g *RootGraph) GetDOT() (string, error) {
var result strings.Builder
result.WriteString(fmt.Sprintf("%s {\n", g.Type))
result.WriteString("node [colorscheme=set312];\n")
result.WriteString("layout = fdp;\n")
nodes := lib.MapValues(g.nodes)
edges := lib.MapValues(g.edges)
writeContituents(&result, nodes...)
writeContituents(&result, edges...)
for _, cluster := range g.clusters {
clusterDOT, err := cluster.GetDOT()
if err != nil {
return "", err
}
result.WriteString(clusterDOT)
}
result.WriteString("}")
return result.String(), nil
}
// GetType implements Graph.
func (r *RootGraph) GetType() GraphType {
return r.Type
}
func constructEdge(graph Graph, name, label, from, to string) Edge {
switch graph.GetType() {
case DIGRAPH:
return &DirectedEdge{Name: name, Label: label, From: from, To: to}
default:
return &UndirectedEdge{Name: name, Label: label, From: from, To: to}
}
}
// AddEdge: adds an edge between two nodes in the graph
func (g *RootGraph) AddEdge(name string, label string, from string, to string) error {
g.edges[name] = constructEdge(g, name, label, from, to)
return nil
}
const numColours = 12
func (n *Node) hash() int {
h := fnv.New32a()
h.Write([]byte(n.Name))
return (int(h.Sum32()) % numColours) + 1
}
func (n *Node) GetDOT() (string, error) {
return fmt.Sprintf("node[label=\"%s\",shape=%s, style=\"filled\", fillcolor=%d, width=%d, height=%d, fixedsize=true] \"%s\";\n",
n.Label, n.Shape, n.hash(), n.Size, n.Size, n.Name), nil
}
func (e *DirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("\"%s\" -> \"%s\" [label=\"%s\"];\n", e.From, e.To, e.Label), nil
}
func (e *UndirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("\"%s\" -- \"%s\" [label=\"%s\"];\n", e.From, e.To, e.Label), nil
}
// AddEdge: adds an edge between two nodes in the graph
func (g *Cluster) AddEdge(name string, label string, from string, to string) error {
g.edges[name] = constructEdge(g, name, label, from, to)
return nil
}
// PutNode: puts a node in the graph
func (g *Cluster) PutNode(name, label string, size int, shape Shape) error {
_, exists := g.nodes[name]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[name] = &Node{Name: name, Label: label, Shape: shape, Size: size}
return nil
}
func (g *Cluster) GetDOT() (string, error) {
var builder strings.Builder
builder.WriteString(fmt.Sprintf("subgraph \"cluster%s\" {\n", g.Label))
builder.WriteString(fmt.Sprintf("label = \"%s\"\n", g.Label))
nodes := lib.MapValues(g.nodes)
edges := lib.MapValues(g.edges)
writeContituents(&builder, nodes...)
writeContituents(&builder, edges...)
builder.WriteString("}\n")
return builder.String(), nil
}
func (g *Cluster) GetType() GraphType {
return g.Type
}
func NewSubGraph(name string, label string, graphType GraphType) *Cluster {
return &Cluster{
Label: name,
Type: graphType,
Name: name,
nodes: make(map[string]*Node),
edges: make(map[string]Edge),
}
}

116
pkg/dot/wg.go Normal file
View File

@ -0,0 +1,116 @@
package graph
import (
"fmt"
"slices"
"github.com/tim-beatham/smegmesh/pkg/ctrlserver"
)
// MeshGraphConverter converts a mesh to a graph
type MeshGraphConverter interface {
// convert the mesh to textual form
Generate() (string, error)
}
type MeshDOTConverter struct {
meshes map[string][]ctrlserver.MeshNode
destinations map[string]interface{}
}
func (c *MeshDOTConverter) Generate() (string, error) {
g := NewGraph("Smegmesh", GRAPH)
for meshId := range c.meshes {
err := c.generateMesh(g, meshId)
if err != nil {
return "", err
}
}
for mesh := range c.meshes {
g.PutNode(mesh, mesh, 1, CIRCLE)
}
for destination := range c.destinations {
g.PutNode(destination, destination, 1, HEXAGON)
}
return g.GetDOT()
}
func (c *MeshDOTConverter) generateMesh(g *RootGraph, meshId string) error {
nodes := c.meshes[meshId]
g.PutNode(meshId, meshId, 1, CIRCLE)
for _, node := range nodes {
c.graphNode(g, node, meshId)
}
for _, node := range nodes {
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, meshId), "", node.PublicKey, meshId)
}
return nil
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *RootGraph, node ctrlserver.MeshNode, meshId string) {
alias := node.Alias
if alias == "" {
alias = node.WgHost[1:len(node.WgHost)-20] + "\\n" + node.WgHost[len(node.WgHost)-20:len(node.WgHost)]
}
g.PutNode(node.PublicKey, alias, 2, CIRCLE)
for _, route := range node.Routes {
if len(route.Path) == 0 {
g.AddEdge(route.Destination, "", node.PublicKey, route.Destination)
continue
}
reversedPath := slices.Clone(route.Path)
slices.Reverse(reversedPath)
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, reversedPath[0]), "", node.PublicKey, reversedPath[0])
for _, mesh := range route.Path {
if _, ok := c.meshes[mesh]; !ok {
c.destinations[mesh] = struct{}{}
}
}
for index := range reversedPath[0 : len(reversedPath)-1] {
routeID := fmt.Sprintf("%s to %s", reversedPath[index], reversedPath[index+1])
g.AddEdge(routeID, "", reversedPath[index], reversedPath[index+1])
}
if route.Destination == "::/0" {
c.destinations[route.Destination] = struct{}{}
lastMesh := reversedPath[len(reversedPath)-1]
routeID := fmt.Sprintf("%s to %s", lastMesh, route.Destination)
g.AddEdge(routeID, "", lastMesh, route.Destination)
}
}
for service := range node.Services {
c.putService(g, service, meshId, node)
}
}
// putService: construct a service node and a link between the nodes
func (c *MeshDOTConverter) putService(g *RootGraph, key, meshId string, node ctrlserver.MeshNode) {
serviceID := fmt.Sprintf("%s%s%s", key, node.PublicKey, meshId)
g.PutNode(serviceID, key, 1, PARALLELOGRAM)
g.AddEdge(fmt.Sprintf("%s to %s", node.PublicKey, serviceID), "", node.PublicKey, serviceID)
}
func NewMeshGraphConverter(meshes map[string][]ctrlserver.MeshNode) MeshGraphConverter {
return &MeshDOTConverter{
meshes: meshes,
destinations: make(map[string]interface{}),
}
}

View File

@ -1,178 +0,0 @@
// Graph allows the definition of a DOT graph in golang
package graph
import (
"errors"
"fmt"
"hash/fnv"
"strings"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
type GraphType string
type Shape string
const (
GRAPH GraphType = "graph"
DIGRAPH = "digraph"
)
const (
CIRCLE Shape = "circle"
STAR Shape = "star"
HEXAGON Shape = "hexagon"
)
type Graph struct {
Type GraphType
Label string
nodes map[string]*Node
edges []Edge
}
type Node struct {
Name string
Shape Shape
}
type Edge interface {
Dottable
}
type DirectedEdge struct {
Label string
From *Node
To *Node
}
type UndirectedEdge struct {
Label string
From *Node
To *Node
}
// Dottable means an implementer can convert the struct to DOT representation
type Dottable interface {
GetDOT() (string, error)
}
func NewGraph(label string, graphType GraphType) *Graph {
return &Graph{Type: graphType, Label: label, nodes: make(map[string]*Node), edges: make([]Edge, 0)}
}
// PutNode: puts a node in the graph
func (g *Graph) PutNode(label string, shape Shape) error {
_, exists := g.nodes[label]
if exists {
// If exists no need to add the ndoe
return nil
}
g.nodes[label] = &Node{Name: label, Shape: shape}
return nil
}
func writeContituents[D Dottable](result *strings.Builder, elements ...D) error {
for _, node := range elements {
dot, err := node.GetDOT()
if err != nil {
return err
}
_, err = result.WriteString(dot)
if err != nil {
return err
}
}
return nil
}
func (g *Graph) GetDOT() (string, error) {
var result strings.Builder
_, err := result.WriteString(fmt.Sprintf("%s {\n", g.Type))
if err != nil {
return "", err
}
_, err = result.WriteString("node [colorscheme=set312];\n")
if err != nil {
return "", err
}
nodes := lib.MapValues(g.nodes)
err = writeContituents(&result, nodes...)
if err != nil {
return "", err
}
err = writeContituents(&result, g.edges...)
if err != nil {
return "", err
}
_, err = result.WriteString("}")
if err != nil {
return "", err
}
return result.String(), nil
}
func (g *Graph) constructEdge(label string, from *Node, to *Node) Edge {
switch g.Type {
case DIGRAPH:
return &DirectedEdge{Label: label, From: from, To: to}
default:
return &UndirectedEdge{Label: label, From: from, To: to}
}
}
// AddEdge: adds an edge between two nodes in the graph
func (g *Graph) AddEdge(label string, from string, to string) error {
fromNode, exists := g.nodes[from]
if !exists {
return errors.New(fmt.Sprintf("Node %s does not exist", from))
}
toNode, exists := g.nodes[to]
if !exists {
return errors.New(fmt.Sprintf("Node %s does not exist", to))
}
g.edges = append(g.edges, g.constructEdge(label, fromNode, toNode))
return nil
}
const numColours = 12
func (n *Node) hash() int {
h := fnv.New32a()
h.Write([]byte(n.Name))
return (int(h.Sum32()) % numColours) + 1
}
func (n *Node) GetDOT() (string, error) {
return fmt.Sprintf("node[shape=%s, style=\"filled\", fillcolor=%d] %s;\n",
n.Shape, n.hash(), n.Name), nil
}
func (e *DirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("%s -> %s;\n", e.From.Name, e.To.Name), nil
}
func (e *UndirectedEdge) GetDOT() (string, error) {
return fmt.Sprintf("%s -- %s;\n", e.From.Name, e.To.Name), nil
}

View File

@ -34,7 +34,7 @@ type CgaParameters struct {
flag byte flag byte
} }
func NewCga(key wgtypes.Key, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParameters, error) { func NewCga(key wgtypes.Key, collisionCount uint8, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParameters, error) {
var params CgaParameters var params CgaParameters
_, err := rand.Read(params.Modifier[:]) _, err := rand.Read(params.Modifier[:])
@ -45,6 +45,7 @@ func NewCga(key wgtypes.Key, subnetPrefix [2 * InterfaceIdLen]byte) (*CgaParamet
params.PublicKey = key params.PublicKey = key
params.SubnetPrefix = subnetPrefix params.SubnetPrefix = subnetPrefix
params.CollisionCount = collisionCount
return &params, nil return &params, nil
} }
@ -78,7 +79,6 @@ func (c *CgaParameters) generateHash1() []byte {
byteVal[hash1Length-1] = c.CollisionCount byteVal[hash1Length-1] = c.CollisionCount
hash := sha1.Sum(byteVal[:]) hash := sha1.Sum(byteVal[:])
return hash[:Hash1Prefix] return hash[:Hash1Prefix]
} }
@ -90,9 +90,6 @@ func clearBit(num, pos int) byte {
} }
func (c *CgaParameters) generateInterface() []byte { func (c *CgaParameters) generateInterface() []byte {
// TODO: On duplicate address detection increment collision.
// Also incorporate SEC
hash1 := c.generateHash1() hash1 := c.generateHash1()
var interfaceId []byte = make([]byte, InterfaceIdLen) var interfaceId []byte = make([]byte, InterfaceIdLen)

View File

@ -7,5 +7,5 @@ import (
) )
type IPAllocator interface { type IPAllocator interface {
GetIP(key wgtypes.Key, meshId string) (net.IP, error) GetIP(key wgtypes.Key, meshId string, collisionCount uint8) (net.IP, error)
} }

View File

@ -39,10 +39,10 @@ func (u *ULABuilder) GetIPNet(meshId string) (*net.IPNet, error) {
return net, nil return net, nil
} }
func (u *ULABuilder) GetIP(key wgtypes.Key, meshId string) (net.IP, error) { func (u *ULABuilder) GetIP(key wgtypes.Key, meshId string, collisionCount uint8) (net.IP, error) {
ulaPrefix := getMeshPrefix(meshId) ulaPrefix := getMeshPrefix(meshId)
c, err := NewCga(key, ulaPrefix) c, err := NewCga(key, collisionCount, ulaPrefix)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -5,11 +5,27 @@ import (
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
ipcRpc "net/rpc"
"os" "os"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
) )
const SockAddr = "/tmp/wgmesh_sock"
type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error
JoinMesh(args *JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error
Query(query QueryMesh, reply *string) error
PutDescription(args PutDescriptionArgs, reply *string) error
PutAlias(args PutAliasArgs, reply *string) error
PutService(args PutServiceArgs, reply *string) error
DeleteService(args DeleteServiceArgs, reply *string) error
}
// WireGuardArgs are provided args specific to WireGuard // WireGuardArgs are provided args specific to WireGuard
type WireGuardArgs struct { type WireGuardArgs struct {
// WgPort is the WireGuard port to expose // WgPort is the WireGuard port to expose
@ -39,7 +55,7 @@ type JoinMeshArgs struct {
// MeshId is the ID of the mesh to join // MeshId is the ID of the mesh to join
MeshId string MeshId string
// IpAddress is a routable IP in another mesh // IpAddress is a routable IP in another mesh
IpAdress string IpAddress string
// WgArgs is the WireGuard parameters to use. // WgArgs is the WireGuard parameters to use.
WgArgs WireGuardArgs WgArgs WireGuardArgs
} }
@ -47,6 +63,22 @@ type JoinMeshArgs struct {
type PutServiceArgs struct { type PutServiceArgs struct {
Service string Service string
Value string Value string
MeshId string
}
type DeleteServiceArgs struct {
Service string
MeshId string
}
type PutAliasArgs struct {
Alias string
MeshId string
}
type PutDescriptionArgs struct {
Description string
MeshId string
} }
type GetMeshReply struct { type GetMeshReply struct {
@ -62,21 +94,78 @@ type QueryMesh struct {
Query string Query string
} }
type MeshIpc interface { type ClientIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error ListMeshes(args *ListMeshReply, reply *string) error
JoinMesh(args JoinMeshArgs, reply *string) error JoinMesh(args JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error GetMesh(meshId string, reply *GetMeshReply) error
GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error PutDescription(args PutDescriptionArgs, reply *string) error
PutAlias(alias string, reply *string) error PutAlias(args PutAliasArgs, reply *string) error
PutService(args PutServiceArgs, reply *string) error PutService(args PutServiceArgs, reply *string) error
DeleteService(service string, reply *string) error DeleteService(args DeleteServiceArgs, reply *string) error
} }
const SockAddr = "/tmp/wgmesh_ipc.sock" type SmegmeshIpc struct {
client *ipcRpc.Client
}
func NewClientIpc() (*SmegmeshIpc, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
return &SmegmeshIpc{
client: client,
}, nil
}
func (c *SmegmeshIpc) CreateMesh(args *NewMeshArgs, reply *string) error {
return c.client.Call("IpcHandler.CreateMesh", args, reply)
}
func (c *SmegmeshIpc) ListMeshes(reply *ListMeshReply) error {
return c.client.Call("IpcHandler.ListMeshes", "", reply)
}
func (c *SmegmeshIpc) JoinMesh(args JoinMeshArgs, reply *string) error {
return c.client.Call("IpcHandler.JoinMesh", &args, reply)
}
func (c *SmegmeshIpc) LeaveMesh(meshId string, reply *string) error {
return c.client.Call("IpcHandler.LeaveMesh", &meshId, reply)
}
func (c *SmegmeshIpc) GetMesh(meshId string, reply *GetMeshReply) error {
return c.client.Call("IpcHandler.GetMesh", &meshId, reply)
}
func (c *SmegmeshIpc) Query(query QueryMesh, reply *string) error {
return c.client.Call("IpcHandler.Query", &query, reply)
}
func (c *SmegmeshIpc) PutDescription(args PutDescriptionArgs, reply *string) error {
return c.client.Call("IpcHandler.PutDescription", &args, reply)
}
func (c *SmegmeshIpc) PutAlias(args PutAliasArgs, reply *string) error {
return c.client.Call("IpcHandler.PutAlias", &args, reply)
}
func (c *SmegmeshIpc) PutService(args PutServiceArgs, reply *string) error {
return c.client.Call("IpcHandler.PutService", &args, reply)
}
func (c *SmegmeshIpc) DeleteService(args DeleteServiceArgs, reply *string) error {
return c.client.Call("IpcHandler.DeleteService", &args, reply)
}
func (c *SmegmeshIpc) Close() error {
return c.Close()
}
func RunIpcHandler(server MeshIpc) error { func RunIpcHandler(server MeshIpc) error {
if err := os.RemoveAll(SockAddr); err != nil { if err := os.RemoveAll(SockAddr); err != nil {

View File

@ -3,6 +3,7 @@ package lib
import ( import (
"github.com/anandvarma/namegen" "github.com/anandvarma/namegen"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lithammer/shortuuid"
) )
// IdGenerator generates unique ids // IdGenerator generates unique ids
@ -19,6 +20,14 @@ func (g *UUIDGenerator) GetId() (string, error) {
return id.String(), nil return id.String(), nil
} }
type ShortIDGenerator struct {
}
func (g *ShortIDGenerator) GetId() (string, error) {
id := shortuuid.New()
return id, nil
}
type IDNameGenerator struct { type IDNameGenerator struct {
} }

View File

@ -6,7 +6,7 @@ import (
"net" "net"
"github.com/jsimonetti/rtnetlink" "github.com/jsimonetti/rtnetlink"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )

View File

@ -7,10 +7,10 @@ import (
"strings" "strings"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/route" "github.com/tim-beatham/smegmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -91,7 +91,11 @@ func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wg
return p.PublicKey.String() == pubKey.String() return p.PublicKey.String() == pubKey.String()
}) })
endpoint, err := net.ResolveUDPAddr("udp", params.node.GetWgEndpoint()) var endpoint *net.UDPAddr = nil
if params.node.GetType() == conf.PEER_ROLE {
endpoint, err = net.ResolveUDPAddr("udp", params.node.GetWgEndpoint())
}
if err != nil { if err != nil {
return nil, err return nil, err
@ -115,7 +119,7 @@ func (m *WgMeshConfigApplyer) convertMeshNode(params convertMeshNodeParams) (*wg
// getRoutes: finds the routes with the least hop distance. If more than one route exists // getRoutes: finds the routes with the least hop distance. If more than one route exists
// consistently hash to evenly spread the distribution of traffic // consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode { func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) (map[string][]routeNode, error) {
mesh, _ := meshProvider.GetMesh() mesh, _ := meshProvider.GetMesh()
routes := make(map[string][]routeNode) routes := make(map[string][]routeNode)
@ -154,17 +158,19 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
// Client's only acessible by another peer // Client's only acessible by another peer
if node.GetType() == conf.CLIENT_ROLE { if node.GetType() == conf.CLIENT_ROLE {
peer := m.getCorrespondingPeer(peers, node) peer := m.getCorrespondingPeer(peers, node)
self, _ := m.meshManager.GetSelf(meshProvider.GetMeshId()) self, err := meshProvider.GetNode(m.meshManager.GetPublicKey().String())
if err != nil {
return nil, err
}
// If the node isn't the self use that peer as the gateway
if !NodeEquals(peer, self) { if !NodeEquals(peer, self) {
peerPub, _ := peer.GetPublicKey() peerPub, _ := peer.GetPublicKey()
rn.gateway = peerPub.String() rn.gateway = peerPub.String()
rn.route = &RouteStub{ rn.route = &RouteStub{
Destination: rn.route.GetDestination(), Destination: rn.route.GetDestination(),
HopCount: rn.route.GetHopCount() + 1, HopCount: rn.route.GetHopCount() + 1,
// Append the path to this peer Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
Path: append(rn.route.GetPath(), peer.GetWgHost().IP.String()),
} }
} }
} }
@ -181,7 +187,7 @@ func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]
} }
} }
return routes return routes, nil
} }
// getCorrespondignPeer: gets the peer corresponding to the client // getCorrespondignPeer: gets the peer corresponding to the client
@ -215,7 +221,6 @@ type GetConfigParams struct {
} }
func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) { func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes.Config, error) {
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId())
ula := &ip.ULABuilder{} ula := &ip.ULABuilder{}
meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId()) meshNet, _ := ula.GetIPNet(params.mesh.GetMeshId())
@ -231,6 +236,8 @@ func (m *WgMeshConfigApplyer) getClientConfig(params *GetConfigParams) (*wgtypes
}) })
routes = append(routes, *meshNet) routes = append(routes, *meshNet)
self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -298,7 +305,7 @@ func (m *WgMeshConfigApplyer) getPeerConfig(params *GetConfigParams) (*wgtypes.C
peerToClients := make(map[string][]net.IPNet) peerToClients := make(map[string][]net.IPNet)
installedRoutes := make([]lib.Route, 0) installedRoutes := make([]lib.Route, 0)
peerConfigs := make([]wgtypes.PeerConfig, 0) peerConfigs := make([]wgtypes.PeerConfig, 0)
self, err := m.meshManager.GetSelf(params.mesh.GetMeshId()) self, err := params.mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil { if err != nil {
return nil, err return nil, err
@ -389,7 +396,7 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string]
return mn.GetType() == conf.CLIENT_ROLE return mn.GetType() == conf.CLIENT_ROLE
}) })
self, err := m.meshManager.GetSelf(mesh.GetMeshId()) self, err := mesh.GetNode(m.meshManager.GetPublicKey().String())
if err != nil { if err != nil {
return err return err
@ -428,11 +435,15 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider, routes map[string]
return nil return nil
} }
func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode { func (m *WgMeshConfigApplyer) getAllRoutes() (map[string][]routeNode, error) {
allRoutes := make(map[string][]routeNode) allRoutes := make(map[string][]routeNode)
for _, mesh := range m.meshManager.GetMeshes() { for _, mesh := range m.meshManager.GetMeshes() {
routes := m.getRoutes(mesh) routes, err := m.getRoutes(mesh)
if err != nil {
return nil, err
}
for destination, route := range routes { for destination, route := range routes {
_, ok := allRoutes[destination] _, ok := allRoutes[destination]
@ -450,11 +461,15 @@ func (m *WgMeshConfigApplyer) getAllRoutes() map[string][]routeNode {
} }
} }
return allRoutes return allRoutes, nil
} }
func (m *WgMeshConfigApplyer) ApplyConfig() error { func (m *WgMeshConfigApplyer) ApplyConfig() error {
allRoutes := m.getAllRoutes() allRoutes, err := m.getAllRoutes()
if err != nil {
return err
}
for _, mesh := range m.meshManager.GetMeshes() { for _, mesh := range m.meshManager.GetMeshes() {
err := m.updateWgConf(mesh, allRoutes) err := m.updateWgConf(mesh, allRoutes)

View File

@ -1,77 +0,0 @@
package mesh
import (
"errors"
"fmt"
"github.com/tim-beatham/wgmesh/pkg/graph"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// MeshGraphConverter converts a mesh to a graph
type MeshGraphConverter interface {
// convert the mesh to textual form
Generate(meshId string) (string, error)
}
type MeshDOTConverter struct {
manager MeshManager
}
func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
mesh := c.manager.GetMesh(meshId)
if mesh == nil {
return "", errors.New("mesh does not exist")
}
g := graph.NewGraph(meshId, graph.GRAPH)
snapshot, err := mesh.GetMesh()
if err != nil {
return "", err
}
for _, node := range snapshot.GetNodes() {
c.graphNode(g, node, meshId)
}
nodes := lib.MapValues(snapshot.GetNodes())
for i, node1 := range nodes[:len(nodes)-1] {
for _, node2 := range nodes[i+1:] {
if node1.GetWgEndpoint() == node2.GetWgEndpoint() {
continue
}
node1Id := fmt.Sprintf("\"%s\"", node1.GetIdentifier())
node2Id := fmt.Sprintf("\"%s\"", node2.GetIdentifier())
g.AddEdge(fmt.Sprintf("%s to %s", node1Id, node2Id), node1Id, node2Id)
}
}
return g.GetDOT()
}
// graphNode: graphs a node within the mesh
func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId string) {
nodeId := fmt.Sprintf("\"%s\"", node.GetIdentifier())
g.PutNode(nodeId, graph.CIRCLE)
self, _ := c.manager.GetSelf(meshId)
if NodeEquals(self, node) {
return
}
for _, route := range node.GetRoutes() {
routeId := fmt.Sprintf("\"%s\"", route)
g.PutNode(routeId, graph.HEXAGON)
g.AddEdge(fmt.Sprintf("%s to %s", nodeId, routeId), nodeId, routeId)
}
}
func NewMeshDotConverter(m MeshManager) MeshGraphConverter {
return &MeshDOTConverter{manager: m}
}

View File

@ -3,13 +3,15 @@ package mesh
import ( import (
"errors" "errors"
"fmt" "fmt"
"net"
"sync" "sync"
"github.com/tim-beatham/wgmesh/pkg/cmd" "github.com/tim-beatham/smegmesh/pkg/cmd"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -24,10 +26,10 @@ type MeshManager interface {
LeaveMesh(meshId string) error LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error) GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error ApplyConfig() error
SetDescription(description string) error SetDescription(meshId, description string) error
SetAlias(alias string) error SetAlias(meshId, alias string) error
SetService(service string, value string) error SetService(meshId, service, value string) error
RemoveService(service string) error RemoveService(meshId, service string) error
UpdateTimeStamp() error UpdateTimeStamp() error
GetClient() *wgctrl.Client GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider GetMeshes() map[string]MeshProvider
@ -61,29 +63,33 @@ func (m *MeshManagerImpl) GetRouteManager() RouteManager {
} }
// RemoveService implements MeshManager. // RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error { func (m *MeshManagerImpl) RemoveService(meshId, service string) error {
for _, mesh := range m.Meshes { mesh := m.GetMesh(meshId)
err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
if err != nil { if mesh == nil {
return err return fmt.Errorf("mesh %s does not exist", meshId)
}
} }
return nil if !mesh.NodeExists(m.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
} }
// SetService implements MeshManager. // SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error { func (m *MeshManagerImpl) SetService(meshId, service, value string) error {
for _, mesh := range m.Meshes { mesh := m.GetMesh(meshId)
err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
if err != nil { if mesh == nil {
return err return fmt.Errorf("mesh %s does not exist", meshId)
}
} }
return nil if !mesh.NodeExists(m.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
} }
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode { func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
@ -134,6 +140,10 @@ func (m *MeshManagerImpl) CreateMesh(args *CreateMeshParams) (string, error) {
return "", err return "", err
} }
if *meshConfiguration.Role == conf.CLIENT_ROLE {
return "", fmt.Errorf("cannot create mesh as a client")
}
meshId, err := m.idGenerator.GetId() meshId, err := m.idGenerator.GetId()
var ifName string = "" var ifName string = ""
@ -277,10 +287,36 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
pubKey := s.HostParameters.PrivateKey.PublicKey() pubKey := s.HostParameters.PrivateKey.PublicKey()
nodeIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId) collisionCount := uint8(0)
if err != nil { var nodeIP net.IP
return err
// Perform Duplicate Address Detection with the nodes
// that are already in the network
for {
generatedIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId, collisionCount)
if err != nil {
return err
}
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
proposition := func(node MeshNode) bool {
ipNet := node.GetWgHost()
return ipNet.IP.Equal(nodeIP)
}
if lib.Contains(lib.MapValues(snapshot.GetNodes()), proposition) {
collisionCount++
} else {
nodeIP = generatedIP
break
}
} }
node := s.nodeFactory.Build(&MeshNodeFactoryParams{ node := s.nodeFactory.Build(&MeshNodeFactoryParams{
@ -320,7 +356,7 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
err := mesh.RemoveNode(s.HostParameters.GetPublicKey()) err := mesh.RemoveNode(s.HostParameters.GetPublicKey())
if err != nil { if err != nil {
return err logging.Log.WriteErrorf(err.Error())
} }
if s.OnDelete != nil { if s.OnDelete != nil {
@ -348,7 +384,6 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
} }
s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...) s.cmdRunner.RunCommands(s.conf.BaseConfiguration.PostDown...)
return err return err
} }
@ -373,43 +408,36 @@ func (s *MeshManagerImpl) ApplyConfig() error {
return nil return nil
} }
err := s.configApplyer.ApplyConfig() return s.configApplyer.ApplyConfig()
if err != nil {
return err
}
return nil
} }
func (s *MeshManagerImpl) SetDescription(description string) error { func (s *MeshManagerImpl) SetDescription(meshId, description string) error {
meshes := s.GetMeshes() mesh := s.GetMesh(meshId)
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
if err != nil { if mesh == nil {
return err return fmt.Errorf("mesh %s does not exist", meshId)
}
}
} }
return nil if !mesh.NodeExists(s.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
} }
// SetAlias implements MeshManager. // SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error { func (s *MeshManagerImpl) SetAlias(meshId, alias string) error {
meshes := s.GetMeshes() mesh := s.GetMesh(meshId)
for _, mesh := range meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
if err != nil { if mesh == nil {
return err return fmt.Errorf("mesh %s does not exist", meshId)
}
}
} }
return nil
if !mesh.NodeExists(s.HostParameters.GetPublicKey()) {
return fmt.Errorf("node %s does not exist in the mesh", meshId)
}
return mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
} }
// UpdateTimeStamp updates the timestamp of this node in all meshes // UpdateTimeStamp updates the timestamp of this node in all meshes

View File

@ -3,10 +3,10 @@ package mesh
import ( import (
"testing" "testing"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg" "github.com/tim-beatham/smegmesh/pkg/wg"
) )
func getMeshConfiguration() *conf.DaemonConfiguration { func getMeshConfiguration() *conf.DaemonConfiguration {
@ -24,8 +24,8 @@ func getMeshConfiguration() *conf.DaemonConfiguration {
Timeout: 5, Timeout: 5,
Profile: false, Profile: false,
StubWg: true, StubWg: true,
SyncRate: 2, SyncTime: 2,
KeepAliveTime: 60, HeartBeat: 60,
ClusterSize: 64, ClusterSize: 64,
InterClusterChance: 0.15, InterClusterChance: 0.15,
BranchRate: 3, BranchRate: 3,
@ -213,7 +213,7 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
} }
} }
func TestSetAlias(t *testing.T) { func TestSetAliasUpdatesAliasOfNode(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
alias := "Firpo" alias := "Firpo"
@ -221,14 +221,13 @@ func TestSetAlias(t *testing.T) {
Port: 5000, Port: 5000,
Conf: &conf.WgConfiguration{}, Conf: &conf.WgConfiguration{},
}) })
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId, MeshId: meshId,
WgPort: 5000, WgPort: 5000,
Endpoint: "abc.com:8080", Endpoint: "abc.com:8080",
}) })
err := manager.SetAlias(alias) err := manager.SetAlias(meshId, alias)
if err != nil { if err != nil {
t.Fatalf(`failed to set the alias`) t.Fatalf(`failed to set the alias`)
@ -245,7 +244,7 @@ func TestSetAlias(t *testing.T) {
} }
} }
func TestSetDescription(t *testing.T) { func TestSetDescriptionSetsTheDescriptionOfTheNode(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
description := "wooooo" description := "wooooo"
@ -254,23 +253,13 @@ func TestSetDescription(t *testing.T) {
Conf: &conf.WgConfiguration{}, Conf: &conf.WgConfiguration{},
}) })
meshId2, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5001,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId1, MeshId: meshId1,
WgPort: 5000, WgPort: 5000,
Endpoint: "abc.com:8080", Endpoint: "abc.com:8080",
}) })
manager.AddSelf(&AddSelfParams{
MeshId: meshId2,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
err := manager.SetDescription(description) err := manager.SetDescription(meshId1, description)
if err != nil { if err != nil {
t.Fatalf(`failed to set the descriptions`) t.Fatalf(`failed to set the descriptions`)
@ -285,18 +274,7 @@ func TestSetDescription(t *testing.T) {
if description != self1.GetDescription() { if description != self1.GetDescription() {
t.Fatalf(`description should be %s was %s`, description, self1.GetDescription()) t.Fatalf(`description should be %s was %s`, description, self1.GetDescription())
} }
self2, err := manager.GetSelf(meshId2)
if err != nil {
t.Fatalf(`failed to set the description`)
}
if description != self2.GetDescription() {
t.Fatalf(`description should be %s was %s`, description, self2.GetDescription())
}
} }
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) { func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
@ -327,3 +305,68 @@ func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
t.Fatalf(`failed to update the timestamp`) t.Fatalf(`failed to update the timestamp`)
} }
} }
func TestAddServiceAddsServiceToTheMesh(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
serviceName := "hello"
manager.SetService(meshId1, serviceName, "dave")
self, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; !ok {
t.Fatalf(`service not added`)
}
}
func TestRemoveServiceRemovesTheServiceFromTheMesh(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh(&CreateMeshParams{
Port: 5000,
Conf: &conf.WgConfiguration{},
})
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
WgPort: 5000,
Endpoint: "abc.com:8080",
})
serviceName := "hello"
manager.SetService(meshId1, serviceName, "dave")
self, err := manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; !ok {
t.Fatalf(`service not added`)
}
manager.RemoveService(meshId1, serviceName)
self, err = manager.GetSelf(meshId1)
if err != nil {
t.Fatalf(`error thrown %s:`, err.Error())
}
if _, ok := self.GetServices()[serviceName]; ok {
t.Fatalf(`service still exists`)
}
}

View File

@ -3,8 +3,8 @@ package mesh
import ( import (
"net" "net"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/smegmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
) )
type RouteManager interface { type RouteManager interface {
@ -24,7 +24,7 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
continue continue
} }
self, err := r.meshManager.GetSelf(mesh1.GetMeshId()) self, err := mesh1.GetNode(r.meshManager.GetPublicKey().String())
if err != nil { if err != nil {
return err return err
@ -90,11 +90,20 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
// Calculate the set different of each, working out routes to remove and to keep. // Calculate the set different of each, working out routes to remove and to keep.
for meshId, meshRoutes := range routes { for meshId, meshRoutes := range routes {
mesh := r.meshManager.GetMesh(meshId) mesh := meshes[meshId]
self, _ := r.meshManager.GetSelf(meshId)
toRemove := make([]Route, 0)
prevRoutes, _ := mesh.GetRoutes(NodeID(self)) self, err := mesh.GetNode(r.meshManager.GetPublicKey().String())
if err != nil {
return err
}
toRemove := make([]Route, 0)
prevRoutes, err := mesh.GetRoutes(NodeID(self))
if err != nil {
return err
}
for _, route := range prevRoutes { for _, route := range prevRoutes {
if !lib.Contains(meshRoutes, func(r Route) bool { if !lib.Contains(meshRoutes, func(r Route) bool {

View File

@ -5,8 +5,8 @@ import (
"net" "net"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -30,8 +30,8 @@ func (*MeshNodeStub) GetType() conf.NodeType {
} }
// GetServices implements MeshNode. // GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string { func (m *MeshNodeStub) GetServices() map[string]string {
return make(map[string]string) return m.services
} }
// GetAlias implements MeshNode. // GetAlias implements MeshNode.
@ -249,6 +249,7 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
routes: make([]Route, 0), routes: make([]Route, 0),
identifier: "abc", identifier: "abc",
description: "A Mesh Node Stub", description: "A Mesh Node Stub",
services: make(map[string]string),
} }
} }
@ -271,32 +272,32 @@ type MeshManagerStub struct {
// GetRouteManager implements MeshManager. // GetRouteManager implements MeshManager.
func (*MeshManagerStub) GetRouteManager() RouteManager { func (*MeshManagerStub) GetRouteManager() RouteManager {
panic("unimplemented") return nil
} }
// GetNode implements MeshManager. // GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode { func (*MeshManagerStub) GetNode(meshId, nodeId string) MeshNode {
panic("unimplemented") return nil
} }
// RemoveService implements MeshManager. // RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error { func (*MeshManagerStub) RemoveService(meshId, service string) error {
panic("unimplemented") return nil
} }
// SetService implements MeshManager. // SetService implements MeshManager.
func (*MeshManagerStub) SetService(service string, value string) error { func (*MeshManagerStub) SetService(meshId, service, value string) error {
panic("unimplemented") return nil
} }
// SetAlias implements MeshManager. // SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(alias string) error { func (*MeshManagerStub) SetAlias(meshId, alias string) error {
panic("unimplemented") return nil
} }
// Close implements MeshManager. // Close implements MeshManager.
func (*MeshManagerStub) Close() error { func (*MeshManagerStub) Close() error {
panic("unimplemented") return nil
} }
// Prune implements MeshManager. // Prune implements MeshManager.
@ -348,7 +349,7 @@ func (m *MeshManagerStub) ApplyConfig() error {
return nil return nil
} }
func (m *MeshManagerStub) SetDescription(description string) error { func (m *MeshManagerStub) SetDescription(meshId, description string) error {
return nil return nil
} }

View File

@ -6,7 +6,7 @@ import (
"net" "net"
"slices" "slices"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -81,6 +81,10 @@ func NodeEquals(node1, node2 MeshNode) bool {
key1, _ := node1.GetPublicKey() key1, _ := node1.GetPublicKey()
key2, _ := node2.GetPublicKey() key2, _ := node2.GetPublicKey()
if node1 == nil || node2 == nil {
return false
}
return key1.String() == key2.String() return key1.String() == key2.String()
} }

View File

@ -6,9 +6,9 @@ import (
"strings" "strings"
"github.com/jmespath/go-jmespath" "github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
// Querier queries a data store for the given data // Querier queries a data store for the given data

View File

@ -7,11 +7,11 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
) )
type IpcHandler struct { type IpcHandler struct {
@ -43,10 +43,6 @@ func getOverrideConfiguration(args *ipc.WireGuardArgs) conf.WgConfiguration {
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
overrideConf := getOverrideConfiguration(&args.WgArgs) overrideConf := getOverrideConfiguration(&args.WgArgs)
if overrideConf.Role != nil && *overrideConf.Role == conf.CLIENT_ROLE {
return fmt.Errorf("cannot create a mesh with no public endpoint")
}
meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{ meshId, err := n.Server.GetMeshManager().CreateMesh(&mesh.CreateMeshParams{
Port: args.WgArgs.WgPort, Port: args.WgArgs.WgPort,
Conf: &overrideConf, Conf: &overrideConf,
@ -83,10 +79,14 @@ func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
return nil return nil
} }
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { func (n *IpcHandler) JoinMesh(args *ipc.JoinMeshArgs, reply *string) error {
overrideConf := getOverrideConfiguration(&args.WgArgs) overrideConf := getOverrideConfiguration(&args.WgArgs)
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAdress) if n.Server.GetMeshManager().GetMesh(args.MeshId) != nil {
return fmt.Errorf("user is already apart of the mesh")
}
peerConnection, err := n.Server.GetConnectionManager().GetConnection(args.IpAddress)
if err != nil { if err != nil {
return err return err
@ -147,7 +147,6 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
if err == nil { if err == nil {
*reply = fmt.Sprintf("Left Mesh %s", meshId) *reply = fmt.Sprintf("Left Mesh %s", meshId)
} }
return err return err
} }
@ -182,19 +181,6 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
return nil return nil
} }
func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())
result, err := g.Generate(meshId)
if err != nil {
return err
}
*reply = result
return nil
}
func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error { func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
queryResponse, err := n.Server.GetQuerier().Query(params.MeshId, params.Query) queryResponse, err := n.Server.GetQuerier().Query(params.MeshId, params.Query)
@ -206,30 +192,34 @@ func (n *IpcHandler) Query(params ipc.QueryMesh, reply *string) error {
return nil return nil
} }
func (n *IpcHandler) PutDescription(description string, reply *string) error { func (n *IpcHandler) PutDescription(args ipc.PutDescriptionArgs, reply *string) error {
err := n.Server.GetMeshManager().SetDescription(description) err := n.Server.GetMeshManager().SetDescription(args.MeshId, args.Description)
if err != nil { if err != nil {
return err return err
} }
*reply = fmt.Sprintf("Set description to %s", description) *reply = fmt.Sprintf("set description to %s for %s", args.Description, args.MeshId)
return nil return nil
} }
func (n *IpcHandler) PutAlias(alias string, reply *string) error { func (n *IpcHandler) PutAlias(args ipc.PutAliasArgs, reply *string) error {
err := n.Server.GetMeshManager().SetAlias(alias) if args.Alias == "" {
return fmt.Errorf("alias not provided")
}
err := n.Server.GetMeshManager().SetAlias(args.MeshId, args.Alias)
if err != nil { if err != nil {
return err return err
} }
*reply = fmt.Sprintf("Set alias to %s", alias) *reply = fmt.Sprintf("Set alias to %s", args.Alias)
return nil return nil
} }
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error { func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.Service, service.Value) err := n.Server.GetMeshManager().SetService(service.MeshId, service.Service, service.Value)
if err != nil { if err != nil {
return err return err
@ -239,8 +229,8 @@ func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error
return nil return nil
} }
func (n *IpcHandler) DeleteService(service string, reply *string) error { func (n *IpcHandler) DeleteService(service ipc.DeleteServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service) err := n.Server.GetMeshManager().RemoveService(service.MeshId, service.Service)
if err != nil { if err != nil {
return err return err

View File

@ -3,10 +3,10 @@ package robin
import ( import (
"testing" "testing"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/smegmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
func getRequester() *IpcHandler { func getRequester() *IpcHandler {

View File

@ -4,8 +4,8 @@ import (
"context" "context"
"errors" "errors"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
) )
type WgRpc struct { type WgRpc struct {

View File

@ -1 +0,0 @@
package robin

View File

@ -1,7 +1,7 @@
package route package route
import ( import (
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )

View File

@ -1,22 +1,22 @@
package sync package sync
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
"sync"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/smegmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
) )
// Syncer: picks random nodes from the meshs // Syncer: picks random nodes from the meshs
type Syncer interface { type Syncer interface {
Sync(meshId string) error Sync(theMesh mesh.MeshProvider) error
SyncMeshes() error SyncMeshes() error
} }
@ -27,54 +27,71 @@ type SyncerImpl struct {
syncCount int syncCount int
cluster conn.ConnCluster cluster conn.ConnCluster
conf *conf.DaemonConfiguration conf *conf.DaemonConfiguration
lastSync map[string]uint64 lastSync map[string]int64
lock sync.RWMutex
} }
// Sync: Sync random nodes // Sync: Sync with random nodes
func (s *SyncerImpl) Sync(meshId string) error { func (s *SyncerImpl) Sync(correspondingMesh mesh.MeshProvider) error {
// Self can be nil if the node is removed if correspondingMesh == nil {
self, _ := s.manager.GetSelf(meshId) return fmt.Errorf("mesh provided was nil cannot sync nil mesh")
}
correspondingMesh := s.manager.GetMesh(meshId) // Self can be nil if the node is removed
selfID := s.manager.GetPublicKey()
self, err := correspondingMesh.GetNode(selfID.String())
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
correspondingMesh.Prune() correspondingMesh.Prune()
if self != nil && self.GetType() == conf.PEER_ROLE && !s.manager.HasChanges(meshId) && s.infectionCount == 0 { if correspondingMesh.HasChanges() {
logging.Log.WriteInfof("No changes for %s", meshId) logging.Log.WriteInfof("meshes %s has changes", correspondingMesh.GetMeshId())
}
// If not synchronised in certain pull from random neighbour // If removed sync with other nodes to gossip the node is removed
if uint64(time.Now().Unix())-s.lastSync[meshId] > 20 { if self != nil && self.GetType() == conf.PEER_ROLE && !correspondingMesh.HasChanges() && s.infectionCount == 0 {
return s.Pull(meshId) logging.Log.WriteInfof("no changes for %s", correspondingMesh.GetMeshId())
// If not synchronised in certain time pull from random neighbour
if s.conf.PullTime != 0 && time.Now().Unix()-s.lastSync[correspondingMesh.GetMeshId()] > int64(s.conf.PullTime) {
return s.Pull(self, correspondingMesh)
} }
return nil return nil
} }
before := time.Now() before := time.Now()
s.manager.GetRouteManager().UpdateRoutes() err = s.manager.GetRouteManager().UpdateRoutes()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
publicKey := s.manager.GetPublicKey() publicKey := s.manager.GetPublicKey()
logging.Log.WriteInfof(publicKey.String())
nodeNames := correspondingMesh.GetPeers() nodeNames := correspondingMesh.GetPeers()
if self != nil { nodeNames = lib.Filter(nodeNames, func(s string) bool {
nodeNames = lib.Filter(nodeNames, func(s string) bool { // Filter our only public key out so we dont sync with ourself
return s != mesh.NodeID(self) return s != publicKey.String()
}) })
}
var gossipNodes []string var gossipNodes []string
// Clients always pings its peer for configuration // Clients always pings its peer for configuration
if self != nil && self.GetType() == conf.CLIENT_ROLE { if self != nil && self.GetType() == conf.CLIENT_ROLE && len(nodeNames) > 1 {
keyFunc := lib.HashString neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
bucketFunc := lib.HashString
neighbour := lib.ConsistentHash(nodeNames, publicKey.String(), keyFunc, bucketFunc) if len(neighbours) == 0 {
gossipNodes = make([]string, 1) return nil
gossipNodes[0] = neighbour }
// Peer with 2 nodes so that there is redundnacy in
// the situation the node leaves pre-emptively
redundancyLength := min(len(neighbours), 2)
gossipNodes = neighbours[:redundancyLength]
} else { } else {
neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String()) neighbours := s.cluster.GetNeighbours(nodeNames, publicKey.String())
gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) gossipNodes = lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
@ -88,65 +105,48 @@ func (s *SyncerImpl) Sync(meshId string) error {
// Do this synchronously to conserve bandwidth // Do this synchronously to conserve bandwidth
for _, node := range gossipNodes { for _, node := range gossipNodes {
correspondingPeer := s.manager.GetNode(meshId, node) correspondingPeer, err := correspondingMesh.GetNode(node)
if correspondingPeer == nil { if correspondingPeer == nil || err != nil {
logging.Log.WriteErrorf("node %s does not exist", node) logging.Log.WriteErrorf("node %s does not exist", node)
continue
} }
err := s.requester.SyncMesh(meshId, correspondingPeer) err = s.requester.SyncMesh(correspondingMesh.GetMeshId(), correspondingPeer)
if err == nil || err == io.EOF { if err == nil || err == io.EOF {
succeeded = true succeeded = true
} else { }
// If the synchronisation operation has failed them mark a gravestone
// preventing the peer from being re-contacted until it has updated if err != nil {
// itself logging.Log.WriteErrorf(err.Error())
s.manager.GetMesh(meshId).Mark(node)
} }
} }
s.syncCount++ s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before)) logging.Log.WriteInfof("sync time: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount) logging.Log.WriteInfof("number of syncs: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount) s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
if !succeeded { if !succeeded {
// If could not gossip with anyone then repeat.
s.infectionCount++ s.infectionCount++
} }
s.manager.GetMesh(meshId).SaveChanges() correspondingMesh.SaveChanges()
s.lastSync[meshId] = uint64(time.Now().Unix())
logging.Log.WriteInfof("UPDATING WG CONF")
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
s.lock.Lock()
s.lastSync[correspondingMesh.GetMeshId()] = time.Now().Unix()
s.lock.Unlock()
return nil return nil
} }
// Pull one node in the cluster, if there has not been message dissemination // Pull one node in the cluster, if there has not been message dissemination
// in a certain period of time pull a random node within the cluster // in a certain period of time pull a random node within the cluster
func (s *SyncerImpl) Pull(meshId string) error { func (s *SyncerImpl) Pull(self mesh.MeshNode, mesh mesh.MeshProvider) error {
mesh := s.manager.GetMesh(meshId) peers := mesh.GetPeers()
self, err := s.manager.GetSelf(meshId)
if err != nil {
return err
}
pubKey, _ := self.GetPublicKey() pubKey, _ := self.GetPublicKey()
if mesh == nil {
return errors.New("mesh is nil, invalid operation")
}
peers := mesh.GetPeers()
neighbours := s.cluster.GetNeighbours(peers, pubKey.String()) neighbours := s.cluster.GetNeighbours(peers, pubKey.String())
neighbour := lib.RandomSubsetOfLength(neighbours, 1) neighbour := lib.RandomSubsetOfLength(neighbours, 1)
@ -155,7 +155,7 @@ func (s *SyncerImpl) Pull(meshId string) error {
return nil return nil
} }
logging.Log.WriteInfof("PULLING from node %s", neighbour[0]) logging.Log.WriteInfof("pulling from node %s", neighbour[0])
pullNode, err := mesh.GetNode(neighbour[0]) pullNode, err := mesh.GetNode(neighbour[0])
@ -163,10 +163,10 @@ func (s *SyncerImpl) Pull(meshId string) error {
return fmt.Errorf("node %s does not exist in the mesh", neighbour[0]) return fmt.Errorf("node %s does not exist in the mesh", neighbour[0])
} }
err = s.requester.SyncMesh(meshId, pullNode) err = s.requester.SyncMesh(mesh.GetMeshId(), pullNode)
if err == nil || err == io.EOF { if err == nil || err == io.EOF {
s.lastSync[meshId] = uint64(time.Now().Unix()) s.lastSync[mesh.GetMeshId()] = time.Now().Unix()
} else { } else {
return err return err
} }
@ -177,14 +177,31 @@ func (s *SyncerImpl) Pull(meshId string) error {
// SyncMeshes: Sync all meshes // SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error { func (s *SyncerImpl) SyncMeshes() error {
for meshId := range s.manager.GetMeshes() { var wg sync.WaitGroup
err := s.Sync(meshId)
if err != nil { for _, mesh := range s.manager.GetMeshes() {
logging.Log.WriteErrorf(err.Error()) wg.Add(1)
sync := func() {
defer wg.Done()
err := s.Sync(mesh)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
} }
}
go sync()
}
wg.Wait()
logging.Log.WriteInfof("updating the WireGuard configuration")
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("failed to update config %w", err)
}
return nil return nil
} }
@ -197,5 +214,5 @@ func NewSyncer(m mesh.MeshManager, conf *conf.DaemonConfiguration, r SyncRequest
infectionCount: 0, infectionCount: 0,
syncCount: 0, syncCount: 0,
cluster: cluster, cluster: cluster,
lastSync: make(map[string]uint64)} lastSync: make(map[string]int64)}
} }

View File

@ -1,8 +1,9 @@
package sync package sync
import ( import (
logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/smegmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/smegmesh/pkg/mesh"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
) )
@ -15,11 +16,34 @@ type SyncErrorHandler interface {
// SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler // SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler
type SyncErrorHandlerImpl struct { type SyncErrorHandlerImpl struct {
meshManager mesh.MeshManager meshManager mesh.MeshManager
connManager conn.ConnectionManager
} }
func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool { func (s *SyncErrorHandlerImpl) handleFailed(meshId string, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId) mesh := s.meshManager.GetMesh(meshId)
mesh.Mark(nodeId) mesh.Mark(nodeId)
node, err := mesh.GetNode(nodeId)
if err != nil {
s.connManager.RemoveConnection(node.GetHostEndpoint())
}
return true
}
func (s *SyncErrorHandlerImpl) handleDeadlineExceeded(meshId string, nodeId string) bool {
mesh := s.meshManager.GetMesh(meshId)
if mesh == nil {
return true
}
node, err := mesh.GetNode(nodeId)
if err != nil {
return false
}
s.connManager.RemoveConnection(node.GetHostEndpoint())
return true return true
} }
@ -29,13 +53,15 @@ func (s *SyncErrorHandlerImpl) Handle(meshId string, nodeId string, err error) b
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message()) logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
switch errStatus.Code() { switch errStatus.Code() {
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound: case codes.Unavailable, codes.Unknown, codes.Internal, codes.NotFound:
return s.handleFailed(meshId, nodeId) return s.handleFailed(meshId, nodeId)
case codes.DeadlineExceeded:
return s.handleDeadlineExceeded(meshId, nodeId)
} }
return false return false
} }
func NewSyncErrorHandler(m mesh.MeshManager) SyncErrorHandler { func NewSyncErrorHandler(m mesh.MeshManager, conn conn.ConnectionManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m} return &SyncErrorHandlerImpl{meshManager: m, connManager: conn}
} }

View File

@ -6,10 +6,10 @@ import (
"io" "io"
"time" "time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
) )
// SyncRequester: coordinates the syncing of meshes // SyncRequester: coordinates the syncing of meshes
@ -91,7 +91,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro
c := rpc.NewSyncServiceClient(client) c := rpc.NewSyncServiceClient(client)
syncTimeOut := float64(s.server.Conf.SyncRate) * float64(time.Second) syncTimeOut := float64(s.server.Conf.SyncTime) * float64(time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut)) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut))
defer cancel() defer cancel()
@ -99,11 +99,11 @@ func (s *SyncRequesterImpl) SyncMesh(meshId string, meshNode mesh.MeshNode) erro
err = s.syncMesh(mesh, ctx, c) err = s.syncMesh(mesh, ctx, c)
if err != nil { if err != nil {
return s.handleErr(meshId, pubKey.String(), err) s.handleErr(meshId, pubKey.String(), err)
} }
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId) logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)
return nil return err
} }
func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error { func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
@ -151,6 +151,6 @@ func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context
} }
func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester { func NewSyncRequester(s *ctrlserver.MeshCtrlServer) SyncRequester {
errorHdlr := NewSyncErrorHandler(s.MeshManager) errorHdlr := NewSyncErrorHandler(s.MeshManager, s.ConnectionManager)
return &SyncRequesterImpl{server: s, errorHdlr: errorHdlr} return &SyncRequesterImpl{server: s, errorHdlr: errorHdlr}
} }

View File

@ -1,8 +1,8 @@
package sync package sync
import ( import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
) )
// Run implements SyncScheduler. // Run implements SyncScheduler.
@ -14,5 +14,5 @@ func syncFunction(syncer Syncer) lib.TimerFunc {
} }
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer { func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester, syncer Syncer) *lib.Timer {
return lib.NewTimer(syncFunction(syncer), s.Conf.SyncRate) return lib.NewTimer(syncFunction(syncer), s.Conf.SyncTime)
} }

View File

@ -6,9 +6,9 @@ import (
"errors" "errors"
"io" "io"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/smegmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/smegmesh/pkg/rpc"
) )
type SyncServiceImpl struct { type SyncServiceImpl struct {

View File

@ -1,9 +1,9 @@
package timer package timer
import ( import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/smegmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
) )
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer { func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
@ -11,5 +11,5 @@ func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
logging.Log.WriteInfof("Updated Timestamp") logging.Log.WriteInfof("Updated Timestamp")
return ctrlServer.MeshManager.UpdateTimeStamp() return ctrlServer.MeshManager.UpdateTimeStamp()
} }
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime) return *lib.NewTimer(timerFunc, ctrlServer.Conf.HeartBeat)
} }

View File

@ -5,8 +5,8 @@ import (
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/smegmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/smegmesh/pkg/log"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )