mirror of
https://github.com/tim-beatham/smegmesh.git
synced 2024-12-04 21:50:49 +01:00
Interfacing out components for unit testing
This commit is contained in:
parent
f1cfd52a91
commit
4c6bbcffcd
@ -22,7 +22,7 @@ func createMesh(client *ipcRpc.Client, ifName string, wgPort int) string {
|
||||
WgPort: wgPort,
|
||||
}
|
||||
|
||||
err := client.Call("RobinIpc.CreateMesh", &newMeshParams, &reply)
|
||||
err := client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply)
|
||||
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
@ -34,7 +34,7 @@ func createMesh(client *ipcRpc.Client, ifName string, wgPort int) string {
|
||||
func listMeshes(client *ipcRpc.Client) {
|
||||
reply := new(ipc.ListMeshReply)
|
||||
|
||||
err := client.Call("RobinIpc.ListMeshes", "", &reply)
|
||||
err := client.Call("IpcHandler.ListMeshes", "", &reply)
|
||||
|
||||
if err != nil {
|
||||
logging.Log.WriteErrorf(err.Error())
|
||||
@ -56,7 +56,7 @@ func joinMesh(client *ipcRpc.Client, meshId string, ipAddress string, ifName str
|
||||
Port: wgPort,
|
||||
}
|
||||
|
||||
err := client.Call("RobinIpc.JoinMesh", &args, &reply)
|
||||
err := client.Call("IpcHandler.JoinMesh", &args, &reply)
|
||||
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
@ -68,7 +68,7 @@ func joinMesh(client *ipcRpc.Client, meshId string, ipAddress string, ifName str
|
||||
func getMesh(client *ipcRpc.Client, meshId string) {
|
||||
reply := new(ipc.GetMeshReply)
|
||||
|
||||
err := client.Call("RobinIpc.GetMesh", &meshId, &reply)
|
||||
err := client.Call("IpcHandler.GetMesh", &meshId, &reply)
|
||||
|
||||
if err != nil {
|
||||
log.Panic(err.Error())
|
||||
@ -92,7 +92,7 @@ func getMesh(client *ipcRpc.Client, meshId string) {
|
||||
func enableInterface(client *ipcRpc.Client, meshId string) {
|
||||
var reply string
|
||||
|
||||
err := client.Call("RobinIpc.EnableInterface", &meshId, &reply)
|
||||
err := client.Call("IpcHandler.EnableInterface", &meshId, &reply)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
@ -105,7 +105,7 @@ func enableInterface(client *ipcRpc.Client, meshId string) {
|
||||
func getGraph(client *ipcRpc.Client, meshId string) {
|
||||
var reply string
|
||||
|
||||
err := client.Call("RobinIpc.GetDOT", &meshId, &reply)
|
||||
err := client.Call("IpcHandler.GetDOT", &meshId, &reply)
|
||||
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
|
@ -28,8 +28,8 @@ func main() {
|
||||
return
|
||||
}
|
||||
|
||||
var robinRpc robin.RobinRpc
|
||||
var robinIpc robin.RobinIpc
|
||||
var robinRpc robin.WgRpc
|
||||
var robinIpc robin.IpcHandler
|
||||
var authProvider middleware.AuthRpcProvider
|
||||
var syncProvider sync.SyncServiceImpl
|
||||
|
||||
|
@ -2,21 +2,22 @@ package crdt
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/automerge/automerge-go"
|
||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||
"github.com/tim-beatham/wgmesh/pkg/mesh"
|
||||
"github.com/tim-beatham/wgmesh/pkg/wg"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// CrdtNodeManager manages nodes in the crdt mesh
|
||||
type CrdtNodeManager struct {
|
||||
// CrdtMeshManager manages nodes in the crdt mesh
|
||||
type CrdtMeshManager struct {
|
||||
MeshId string
|
||||
IfName string
|
||||
NodeId string
|
||||
@ -26,57 +27,63 @@ type CrdtNodeManager struct {
|
||||
conf *conf.WgMeshConfiguration
|
||||
}
|
||||
|
||||
const maxFails = 5
|
||||
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
|
||||
crdt, ok := node.(*MeshNodeCrdt)
|
||||
|
||||
if !ok {
|
||||
panic("node must be of type *MeshNodeCrdt")
|
||||
}
|
||||
|
||||
func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) {
|
||||
crdt.FailedMap = automerge.NewMap()
|
||||
crdt.Timestamp = time.Now().Unix()
|
||||
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
|
||||
nodeVal, _ := c.doc.Path("nodes").Map().Get(crdt.HostEndpoint)
|
||||
nodeVal.Map().Set("routes", automerge.NewMap())
|
||||
}
|
||||
|
||||
func (c *CrdtNodeManager) ApplyWg() error {
|
||||
snapshot, err := c.GetCrdt()
|
||||
func (c *CrdtMeshManager) ApplyWg() error {
|
||||
// snapshot, err := c.GetMesh()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
c.updateWgConf(c.IfName, snapshot.Nodes, *c.Client)
|
||||
// c.updateWgConf(c.IfName, snapshot.GetNodes(), *c.Client)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCrdt(): Converts the document into a struct
|
||||
func (c *CrdtNodeManager) GetCrdt() (*MeshCrdt, error) {
|
||||
// GetMesh(): Converts the document into a struct
|
||||
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
|
||||
return automerge.As[*MeshCrdt](c.doc.Root())
|
||||
}
|
||||
|
||||
// GetMeshId returns the meshid of the mesh
|
||||
func (c *CrdtMeshManager) GetMeshId() string {
|
||||
return c.MeshId
|
||||
}
|
||||
|
||||
// Save: Save an entire mesh network
|
||||
func (c *CrdtMeshManager) Save() []byte {
|
||||
return c.doc.Save()
|
||||
}
|
||||
|
||||
// Load: Load an entire mesh network
|
||||
func (c *CrdtNodeManager) Load(bytes []byte) error {
|
||||
func (c *CrdtMeshManager) Load(bytes []byte) error {
|
||||
doc, err := automerge.Load(bytes)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.doc = doc
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save: Save an entire mesh network
|
||||
func (c *CrdtNodeManager) Save() []byte {
|
||||
return c.doc.Save()
|
||||
}
|
||||
|
||||
// NewCrdtNodeManager: Create a new crdt node manager
|
||||
func NewCrdtNodeManager(meshId, hostId, devName string, port int, conf conf.WgMeshConfiguration, client *wgctrl.Client) (*CrdtNodeManager, error) {
|
||||
var manager CrdtNodeManager
|
||||
func NewCrdtNodeManager(meshId, devName string, port int, conf conf.WgMeshConfiguration, client *wgctrl.Client) (*CrdtMeshManager, error) {
|
||||
var manager CrdtMeshManager
|
||||
manager.MeshId = meshId
|
||||
manager.doc = automerge.New()
|
||||
manager.IfName = devName
|
||||
manager.Client = client
|
||||
manager.NodeId = hostId
|
||||
manager.conf = &conf
|
||||
|
||||
err := wg.CreateWgInterface(client, devName, port)
|
||||
@ -88,7 +95,7 @@ func NewCrdtNodeManager(meshId, hostId, devName string, port int, conf conf.WgMe
|
||||
return &manager, nil
|
||||
}
|
||||
|
||||
func (m *CrdtNodeManager) convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfig, error) {
|
||||
func (m *CrdtMeshManager) convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfig, error) {
|
||||
peerEndpoint, err := net.ResolveUDPAddr("udp", node.WgEndpoint)
|
||||
|
||||
if err != nil {
|
||||
@ -125,45 +132,7 @@ func (m *CrdtNodeManager) convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfi
|
||||
return &peerConfig, nil
|
||||
}
|
||||
|
||||
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
|
||||
return strings.Compare(m1.PublicKey, m2.PublicKey)
|
||||
}
|
||||
|
||||
func (c *CrdtNodeManager) changeFailedCount(meshId, endpoint string, incAmount int64) error {
|
||||
node, err := c.doc.Path("nodes").Map().Get(endpoint)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
counterMap, err := node.Map().Get("failedMap")
|
||||
|
||||
if counterMap.Kind() == automerge.KindVoid {
|
||||
return errors.New("Something went wrong map does not exist")
|
||||
}
|
||||
|
||||
counter, _ := counterMap.Map().Get(c.NodeId)
|
||||
|
||||
if counter.Kind() == automerge.KindVoid {
|
||||
err = counterMap.Map().Set(c.NodeId, incAmount)
|
||||
} else {
|
||||
if counter.Int64()+incAmount < 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = counterMap.Map().Set(c.NodeId, counter.Int64()+1)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Increment failed count increments the number of times we have attempted
|
||||
// to contact the node and it's failed
|
||||
func (c *CrdtNodeManager) IncrementFailedCount(endpoint string) error {
|
||||
return c.changeFailedCount(c.MeshId, endpoint, +1)
|
||||
}
|
||||
|
||||
func (c *CrdtNodeManager) removeNode(endpoint string) error {
|
||||
func (c *CrdtMeshManager) removeNode(endpoint string) error {
|
||||
err := c.doc.Path("nodes").Map().Delete(endpoint)
|
||||
|
||||
if err != nil {
|
||||
@ -173,14 +142,8 @@ func (c *CrdtNodeManager) removeNode(endpoint string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decrement failed count decrements the number of times we have attempted to
|
||||
// contact the node and it's failed
|
||||
func (c *CrdtNodeManager) DecrementFailedCount(endpoint string) error {
|
||||
return c.changeFailedCount(c.MeshId, endpoint, -1)
|
||||
}
|
||||
|
||||
// GetNode: returns a mesh node crdt.
|
||||
func (m *CrdtNodeManager) GetNode(endpoint string) (*MeshNodeCrdt, error) {
|
||||
func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) {
|
||||
node, err := m.doc.Path("nodes").Map().Get(endpoint)
|
||||
|
||||
if err != nil {
|
||||
@ -196,11 +159,11 @@ func (m *CrdtNodeManager) GetNode(endpoint string) (*MeshNodeCrdt, error) {
|
||||
return meshNode, nil
|
||||
}
|
||||
|
||||
func (m *CrdtNodeManager) Length() int {
|
||||
func (m *CrdtMeshManager) Length() int {
|
||||
return m.doc.Path("nodes").Map().Len()
|
||||
}
|
||||
|
||||
func (m *CrdtNodeManager) GetDevice() (*wgtypes.Device, error) {
|
||||
func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) {
|
||||
dev, err := m.Client.Device(m.IfName)
|
||||
|
||||
if err != nil {
|
||||
@ -211,7 +174,7 @@ func (m *CrdtNodeManager) GetDevice() (*wgtypes.Device, error) {
|
||||
}
|
||||
|
||||
// HasChanges returns true if we have changes since the last time we synced
|
||||
func (m *CrdtNodeManager) HasChanges() bool {
|
||||
func (m *CrdtMeshManager) HasChanges() bool {
|
||||
changes, err := m.doc.Changes(m.LastHash)
|
||||
|
||||
logging.Log.WriteInfof("Changes %s", m.LastHash.String())
|
||||
@ -224,34 +187,11 @@ func (m *CrdtNodeManager) HasChanges() bool {
|
||||
return len(changes) > 0
|
||||
}
|
||||
|
||||
func (m *CrdtNodeManager) HasFailed(endpoint string) bool {
|
||||
node, err := m.GetNode(endpoint)
|
||||
|
||||
if err != nil {
|
||||
logging.Log.WriteErrorf("Cannot get node node: %s\n", endpoint)
|
||||
return true
|
||||
func (m *CrdtMeshManager) HasFailed(endpoint string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
values, err := node.FailedMap.Values()
|
||||
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
|
||||
countFailed := 0
|
||||
|
||||
for _, value := range values {
|
||||
count := value.Int64()
|
||||
|
||||
if count >= 1 {
|
||||
countFailed++
|
||||
}
|
||||
}
|
||||
|
||||
return countFailed >= 4
|
||||
}
|
||||
|
||||
func (m *CrdtNodeManager) SaveChanges() {
|
||||
func (m *CrdtMeshManager) SaveChanges() {
|
||||
hashes := m.doc.Heads()
|
||||
hash := hashes[len(hashes)-1]
|
||||
|
||||
@ -259,13 +199,17 @@ func (m *CrdtNodeManager) SaveChanges() {
|
||||
m.LastHash = hash
|
||||
}
|
||||
|
||||
func (m *CrdtNodeManager) UpdateTimeStamp() error {
|
||||
node, err := m.doc.Path("nodes").Map().Get(m.NodeId)
|
||||
func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
|
||||
node, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if node.Kind() != automerge.KindMap {
|
||||
return errors.New("node is not a map")
|
||||
}
|
||||
|
||||
err = node.Map().Set("timestamp", time.Now().Unix())
|
||||
|
||||
if err == nil {
|
||||
@ -275,7 +219,32 @@ func (m *CrdtNodeManager) UpdateTimeStamp() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
|
||||
// AddRoutes: adds routes to the specific nodeId
|
||||
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
|
||||
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
routeMap, err := nodeVal.Map().Get("routes")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
err = routeMap.Map().Set(route, struct{}{})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *CrdtMeshManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
|
||||
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
|
||||
|
||||
var count int = 0
|
||||
@ -300,35 +269,58 @@ func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNode
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddRoutes: adds routes to the specific nodeId
|
||||
func (m *CrdtNodeManager) AddRoutes(routes ...string) error {
|
||||
nodeVal, err := m.doc.Path("nodes").Map().Get(m.NodeId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
routeMap, err := nodeVal.Map().Get("routes")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, route := range routes {
|
||||
err = routeMap.Map().Set(route, struct{}{})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *CrdtNodeManager) GetSyncer() *AutomergeSync {
|
||||
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
|
||||
return NewAutomergeSync(m)
|
||||
}
|
||||
|
||||
func (n *MeshNodeCrdt) GetEscapedIP() string {
|
||||
return fmt.Sprintf("\"%s\"", n.WgHost)
|
||||
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
|
||||
return strings.Compare(m1.PublicKey, m2.PublicKey)
|
||||
}
|
||||
|
||||
func (m *MeshNodeCrdt) GetHostEndpoint() string {
|
||||
return m.HostEndpoint
|
||||
}
|
||||
|
||||
func (m *MeshNodeCrdt) GetPublicKey() (wgtypes.Key, error) {
|
||||
return wgtypes.ParseKey(m.PublicKey)
|
||||
}
|
||||
|
||||
func (m *MeshNodeCrdt) GetWgEndpoint() string {
|
||||
return m.HostEndpoint
|
||||
}
|
||||
|
||||
func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
|
||||
_, ipnet, err := net.ParseCIDR(m.WgHost)
|
||||
|
||||
if err != nil {
|
||||
logging.Log.WriteErrorf("Cannot parse WgHost %s", err.Error())
|
||||
return nil
|
||||
}
|
||||
|
||||
return ipnet
|
||||
}
|
||||
|
||||
func (m *MeshNodeCrdt) GetTimeStamp() int64 {
|
||||
return m.Timestamp
|
||||
}
|
||||
|
||||
func (m *MeshNodeCrdt) GetRoutes() []string {
|
||||
return lib.MapKeys(m.Routes)
|
||||
}
|
||||
|
||||
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
|
||||
nodes := make(map[string]mesh.MeshNode)
|
||||
|
||||
for _, node := range m.Nodes {
|
||||
nodes[node.HostEndpoint] = &MeshNodeCrdt{
|
||||
HostEndpoint: node.HostEndpoint,
|
||||
WgEndpoint: node.WgEndpoint,
|
||||
PublicKey: node.PublicKey,
|
||||
WgHost: node.WgHost,
|
||||
Timestamp: node.Timestamp,
|
||||
Routes: node.Routes,
|
||||
}
|
||||
}
|
||||
|
||||
return nodes
|
||||
}
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
|
||||
type AutomergeSync struct {
|
||||
state *automerge.SyncState
|
||||
manager *CrdtNodeManager
|
||||
manager *CrdtMeshManager
|
||||
}
|
||||
|
||||
func (a *AutomergeSync) GenerateMessage() ([]byte, bool) {
|
||||
@ -35,7 +35,7 @@ func (a *AutomergeSync) Complete() {
|
||||
a.manager.SaveChanges()
|
||||
}
|
||||
|
||||
func NewAutomergeSync(manager *CrdtNodeManager) *AutomergeSync {
|
||||
func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync {
|
||||
return &AutomergeSync{
|
||||
state: automerge.NewSyncState(manager.doc),
|
||||
manager: manager,
|
||||
|
10
pkg/automerge/automergefactory.go
Normal file
10
pkg/automerge/automergefactory.go
Normal file
@ -0,0 +1,10 @@
|
||||
package crdt
|
||||
|
||||
import "github.com/tim-beatham/wgmesh/pkg/mesh"
|
||||
|
||||
type CrdtProviderFactory struct{}
|
||||
|
||||
func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
|
||||
return NewCrdtNodeManager(params.MeshId, params.DevName, params.Port,
|
||||
*params.Conf, params.Client)
|
||||
}
|
@ -1,7 +1,5 @@
|
||||
package crdt
|
||||
|
||||
import "github.com/automerge/automerge-go"
|
||||
|
||||
// MeshNodeCrdt: Represents a CRDT for a mesh nodes
|
||||
type MeshNodeCrdt struct {
|
||||
HostEndpoint string `automerge:"hostEndpoint"`
|
||||
@ -9,7 +7,6 @@ type MeshNodeCrdt struct {
|
||||
PublicKey string `automerge:"publicKey"`
|
||||
WgHost string `automerge:"wgHost"`
|
||||
Timestamp int64 `automerge:"timestamp"`
|
||||
FailedMap *automerge.Map `automerge:"failedMap"`
|
||||
Routes map[string]interface{} `automerge:"routes"`
|
||||
}
|
||||
|
||||
|
@ -13,7 +13,11 @@ type WgMeshConfiguration struct {
|
||||
PrivateKeyPath string `yaml:"privateKeyPath"`
|
||||
SkipCertVerification bool `yaml:"skipCertVerification"`
|
||||
GrpcPort string `yaml:"gRPCPort"`
|
||||
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes
|
||||
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
|
||||
// PublicEndpoint is the IP in which this computer is publicly reachable.
|
||||
// usecase is when the node is behind NAT.
|
||||
PublicEndpoint string `yaml:"publicEndpoint"`
|
||||
}
|
||||
|
||||
func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package ctrlserver
|
||||
|
||||
import (
|
||||
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
|
||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||
"github.com/tim-beatham/wgmesh/pkg/conn"
|
||||
"github.com/tim-beatham/wgmesh/pkg/mesh"
|
||||
@ -21,7 +22,8 @@ type NewCtrlServerParams struct {
|
||||
// operation failed
|
||||
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
||||
ctrlServer := new(MeshCtrlServer)
|
||||
ctrlServer.MeshManager = mesh.NewMeshManager(*params.Conf, params.Client)
|
||||
factory := crdt.CrdtProviderFactory{}
|
||||
ctrlServer.MeshManager = mesh.NewMeshManager(*params.Conf, params.Client, &factory)
|
||||
ctrlServer.Conf = params.Conf
|
||||
|
||||
connManagerParams := conn.NewConnectionManageParams{
|
||||
|
@ -16,7 +16,6 @@ type MeshNode struct {
|
||||
WgEndpoint string
|
||||
PublicKey string
|
||||
WgHost string
|
||||
Failed bool
|
||||
Timestamp int64
|
||||
Routes []string
|
||||
}
|
||||
@ -32,7 +31,7 @@ type Mesh struct {
|
||||
*/
|
||||
type MeshCtrlServer struct {
|
||||
Client *wgctrl.Client
|
||||
MeshManager *mesh.MeshManger
|
||||
MeshManager *mesh.MeshManager
|
||||
ConnectionManager conn.ConnectionManager
|
||||
ConnectionServer *conn.ConnectionServer
|
||||
Conf *conf.WgMeshConfiguration
|
||||
|
@ -1,15 +0,0 @@
|
||||
/*
|
||||
* RPC component of the server
|
||||
*/
|
||||
package ctrlserver
|
||||
|
||||
import (
|
||||
"github.com/tim-beatham/wgmesh/pkg/rpc"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func NewRpcServer(server rpc.MeshCtrlServerServer) *grpc.Server {
|
||||
grpc := grpc.NewServer()
|
||||
rpc.RegisterMeshCtrlServerServer(grpc, server)
|
||||
return grpc
|
||||
}
|
@ -1,5 +1,9 @@
|
||||
package lib
|
||||
|
||||
import (
|
||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||
)
|
||||
|
||||
// MapToSlice converts a map to a slice in go
|
||||
func MapValues[K comparable, V any](m map[K]V) []V {
|
||||
return MapValuesWithExclude(m, map[K]struct{}{})
|
||||
@ -19,6 +23,8 @@ func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}
|
||||
continue
|
||||
}
|
||||
|
||||
logging.Log.WriteInfof("Key %s", k)
|
||||
|
||||
values[i] = v
|
||||
i++
|
||||
}
|
||||
|
@ -5,14 +5,13 @@ import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// GetOutboundIP: gets the oubound IP of this packet
|
||||
func GetOutboundIP() net.IP {
|
||||
conn, err := net.Dial("udp", "8.8.8.8:80")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
|
||||
return localAddr.IP
|
||||
}
|
||||
|
@ -8,13 +8,14 @@ import (
|
||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||
)
|
||||
|
||||
// MeshGraphConverter converts a mesh to a graph
|
||||
type MeshGraphConverter interface {
|
||||
// convert the mesh to textual form
|
||||
Generate(meshId string) (string, error)
|
||||
}
|
||||
|
||||
type MeshDOTConverter struct {
|
||||
manager *MeshManger
|
||||
manager *MeshManager
|
||||
}
|
||||
|
||||
func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
|
||||
@ -26,35 +27,33 @@ func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
|
||||
|
||||
g := graph.NewGraph(meshId, graph.GRAPH)
|
||||
|
||||
snapshot, err := mesh.GetCrdt()
|
||||
snapshot, err := mesh.GetMesh()
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
for _, node := range snapshot.Nodes {
|
||||
g.AddNode(node.GetEscapedIP())
|
||||
for _, node := range snapshot.GetNodes() {
|
||||
g.AddNode(fmt.Sprintf("\"%s\"", node.GetWgHost().IP.String()))
|
||||
}
|
||||
|
||||
nodes := lib.MapValues(snapshot.Nodes)
|
||||
nodes := lib.MapValues(snapshot.GetNodes())
|
||||
|
||||
for i, node1 := range nodes[:len(nodes)-1] {
|
||||
if mesh.HasFailed(node1.HostEndpoint) {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, node2 := range nodes[i+1:] {
|
||||
if node1.WgEndpoint == node2.WgEndpoint || mesh.HasFailed(node2.HostEndpoint) {
|
||||
if node1.GetWgEndpoint() == node2.GetWgEndpoint() {
|
||||
continue
|
||||
}
|
||||
|
||||
g.AddEdge(fmt.Sprintf("%s to %s", node1.GetEscapedIP(), node2.GetEscapedIP()), node1.GetEscapedIP(), node2.GetEscapedIP())
|
||||
node1Id := fmt.Sprintf("\"%s\"", node1.GetWgHost().IP.String())
|
||||
node2Id := fmt.Sprintf("\"%s\"", node2.GetWgHost().IP.String())
|
||||
g.AddEdge(fmt.Sprintf("%s to %s", node1Id, node2Id), node1Id, node2Id)
|
||||
}
|
||||
}
|
||||
|
||||
return g.GetDOT()
|
||||
}
|
||||
|
||||
func NewMeshDotConverter(m *MeshManger) MeshGraphConverter {
|
||||
func NewMeshDotConverter(m *MeshManager) MeshGraphConverter {
|
||||
return &MeshDOTConverter{manager: m}
|
||||
}
|
@ -1,159 +0,0 @@
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
|
||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||
"github.com/tim-beatham/wgmesh/pkg/wg"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type MeshManger struct {
|
||||
Meshes map[string]*crdt.CrdtNodeManager
|
||||
RouteManager RouteManager
|
||||
Client *wgctrl.Client
|
||||
HostEndpoint string
|
||||
conf *conf.WgMeshConfiguration
|
||||
}
|
||||
|
||||
func (m *MeshManger) MeshExists(meshId string) bool {
|
||||
_, inMesh := m.Meshes[meshId]
|
||||
return inMesh
|
||||
}
|
||||
|
||||
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
|
||||
func (m *MeshManger) CreateMesh(devName string, port int) (string, error) {
|
||||
key, err := wgtypes.GenerateKey()
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nodeManager, err := crdt.NewCrdtNodeManager(key.String(), m.HostEndpoint, devName, port, *m.conf, m.Client)
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
m.Meshes[key.String()] = nodeManager
|
||||
|
||||
return key.String(), err
|
||||
}
|
||||
|
||||
// AddMesh: Add the mesh to the list of meshes
|
||||
func (m *MeshManger) AddMesh(meshId string, devName string, port int, meshBytes []byte) error {
|
||||
mesh, err := crdt.NewCrdtNodeManager(meshId, m.HostEndpoint, devName, port, *m.conf, m.Client)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = mesh.Load(meshBytes)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Meshes[meshId] = mesh
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMeshNode: Add a mesh node
|
||||
func (m *MeshManger) AddMeshNode(meshId string, node crdt.MeshNodeCrdt) {
|
||||
m.Meshes[meshId].AddNode(node)
|
||||
|
||||
if m.conf.AdvertiseRoutes {
|
||||
m.RouteManager.UpdateRoutes()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MeshManger) HasChanges(meshId string) bool {
|
||||
return m.Meshes[meshId].HasChanges()
|
||||
}
|
||||
|
||||
func (m *MeshManger) GetMesh(meshId string) *crdt.CrdtNodeManager {
|
||||
theMesh, _ := m.Meshes[meshId]
|
||||
return theMesh
|
||||
}
|
||||
|
||||
// EnableInterface: Enables the given WireGuard interface.
|
||||
func (s *MeshManger) EnableInterface(meshId string) error {
|
||||
mesh, contains := s.Meshes[meshId]
|
||||
|
||||
if !contains {
|
||||
return errors.New("Mesh does not exist")
|
||||
}
|
||||
|
||||
crdt, err := mesh.GetCrdt()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
node, contains := crdt.Nodes[s.HostEndpoint]
|
||||
|
||||
if !contains {
|
||||
return errors.New("Node does not exist in the mesh")
|
||||
}
|
||||
|
||||
err = mesh.ApplyWg()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = wg.EnableInterface(mesh.IfName, node.WgHost)
|
||||
|
||||
if s.conf.AdvertiseRoutes {
|
||||
s.RouteManager.ApplyWg(mesh)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPublicKey: Gets the public key of the WireGuard mesh
|
||||
func (s *MeshManger) GetPublicKey(meshId string) (*wgtypes.Key, error) {
|
||||
mesh, ok := s.Meshes[meshId]
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("mesh does not exist")
|
||||
}
|
||||
|
||||
dev, err := mesh.GetDevice()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dev.PublicKey, nil
|
||||
}
|
||||
|
||||
// UpdateTimeStamp updates the timestamp of this node in all meshes
|
||||
func (s *MeshManger) UpdateTimeStamp() error {
|
||||
for _, mesh := range s.Meshes {
|
||||
err := mesh.UpdateTimeStamp()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewMeshManager(conf conf.WgMeshConfiguration, client *wgctrl.Client) *MeshManger {
|
||||
ip := lib.GetOutboundIP()
|
||||
m := &MeshManger{
|
||||
Meshes: make(map[string]*crdt.CrdtNodeManager),
|
||||
HostEndpoint: fmt.Sprintf("%s:%s", ip.String(), conf.GrpcPort),
|
||||
Client: client,
|
||||
conf: &conf,
|
||||
}
|
||||
|
||||
m.RouteManager = NewRouteManager(m)
|
||||
return m
|
||||
}
|
100
pkg/mesh/meshconfig.go
Normal file
100
pkg/mesh/meshconfig.go
Normal file
@ -0,0 +1,100 @@
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// MeshConfigApplyer abstracts applying the mesh configuration
|
||||
type MeshConfigApplyer interface {
|
||||
ApplyConfig() error
|
||||
}
|
||||
|
||||
// WgMeshConfigApplyer applies WireGuard configuration
|
||||
type WgMeshConfigApplyer struct {
|
||||
meshManager *MeshManager
|
||||
}
|
||||
|
||||
func ConvertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
|
||||
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pubKey, err := node.GetPublicKey()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
allowedips := make([]net.IPNet, 1)
|
||||
allowedips[0] = *node.GetWgHost()
|
||||
|
||||
for _, route := range node.GetRoutes() {
|
||||
_, ipnet, _ := net.ParseCIDR(route)
|
||||
allowedips = append(allowedips, *ipnet)
|
||||
}
|
||||
|
||||
peerConfig := wgtypes.PeerConfig{
|
||||
PublicKey: pubKey,
|
||||
Endpoint: endpoint,
|
||||
AllowedIPs: allowedips,
|
||||
}
|
||||
|
||||
return &peerConfig, nil
|
||||
}
|
||||
|
||||
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
|
||||
snap, err := mesh.GetMesh()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nodes := snap.GetNodes()
|
||||
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
|
||||
|
||||
var count int = 0
|
||||
|
||||
for _, n := range nodes {
|
||||
peer, err := ConvertMeshNode(n)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peerConfigs[count] = *peer
|
||||
count++
|
||||
}
|
||||
|
||||
cfg := wgtypes.Config{
|
||||
Peers: peerConfigs,
|
||||
ReplacePeers: true,
|
||||
}
|
||||
|
||||
dev, err := mesh.GetDevice()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return m.meshManager.Client.ConfigureDevice(dev.Name, cfg)
|
||||
}
|
||||
|
||||
func (m *WgMeshConfigApplyer) ApplyConfig() error {
|
||||
for _, mesh := range m.meshManager.Meshes {
|
||||
err := m.updateWgConf(mesh)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewWgMeshConfigApplyer(manager *MeshManager) MeshConfigApplyer {
|
||||
return &WgMeshConfigApplyer{meshManager: manager}
|
||||
}
|
43
pkg/mesh/meshinterface.go
Normal file
43
pkg/mesh/meshinterface.go
Normal file
@ -0,0 +1,43 @@
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/tim-beatham/wgmesh/pkg/wg"
|
||||
)
|
||||
|
||||
// MeshInterfaces manipulates interfaces to do with meshes
|
||||
type MeshInterface interface {
|
||||
EnableInterface(meshId string) error
|
||||
}
|
||||
|
||||
type WgMeshInterface struct {
|
||||
manager *MeshManager
|
||||
}
|
||||
|
||||
// EnableInterface enables the interface at the given endpoint
|
||||
func (m *WgMeshInterface) EnableInterface(meshId string) error {
|
||||
mesh, ok := m.manager.Meshes[meshId]
|
||||
|
||||
if !ok {
|
||||
return errors.New("the provided mesh does not exist")
|
||||
}
|
||||
|
||||
dev, err := mesh.GetDevice()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
self, err := m.manager.GetSelf(meshId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return wg.EnableInterface(dev.Name, self.GetWgHost().String())
|
||||
}
|
||||
|
||||
func NewWgMeshInterface(manager *MeshManager) MeshInterface {
|
||||
return &WgMeshInterface{manager: manager}
|
||||
}
|
179
pkg/mesh/meshmanager.go
Normal file
179
pkg/mesh/meshmanager.go
Normal file
@ -0,0 +1,179 @@
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
type MeshManager struct {
|
||||
Meshes map[string]MeshProvider
|
||||
RouteManager RouteManager
|
||||
Client *wgctrl.Client
|
||||
// HostParameters contains information that uniquely locates
|
||||
// the node in the mesh network.
|
||||
HostParameters *HostParameters
|
||||
conf *conf.WgMeshConfiguration
|
||||
meshProviderFactory MeshProviderFactory
|
||||
configApplyer MeshConfigApplyer
|
||||
interfaceEnabler MeshInterface
|
||||
}
|
||||
|
||||
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
|
||||
func (m *MeshManager) CreateMesh(devName string, port int) (string, error) {
|
||||
key, err := wgtypes.GenerateKey()
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
|
||||
DevName: devName,
|
||||
Port: port,
|
||||
Conf: m.conf,
|
||||
Client: m.Client,
|
||||
MeshId: key.String(),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
m.Meshes[key.String()] = nodeManager
|
||||
|
||||
return key.String(), err
|
||||
}
|
||||
|
||||
// AddMesh: Add the mesh to the list of meshes
|
||||
func (m *MeshManager) AddMesh(meshId string, devName string, port int, meshBytes []byte) error {
|
||||
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
|
||||
DevName: devName,
|
||||
Port: port,
|
||||
Conf: m.conf,
|
||||
Client: m.Client,
|
||||
MeshId: meshId,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = meshProvider.Load(meshBytes)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Meshes[meshId] = meshProvider
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddMeshNode: Add a mesh node
|
||||
func (m *MeshManager) AddMeshNode(meshId string, node MeshNode) {
|
||||
m.Meshes[meshId].AddNode(node)
|
||||
}
|
||||
|
||||
// HasChanges returns true if the mesh has changes
|
||||
func (m *MeshManager) HasChanges(meshId string) bool {
|
||||
return m.Meshes[meshId].HasChanges()
|
||||
}
|
||||
|
||||
// GetMesh returns the mesh with the given meshid
|
||||
func (m *MeshManager) GetMesh(meshId string) MeshProvider {
|
||||
theMesh, _ := m.Meshes[meshId]
|
||||
return theMesh
|
||||
}
|
||||
|
||||
// EnableInterface: Enables the given WireGuard interface.
|
||||
func (s *MeshManager) EnableInterface(meshId string) error {
|
||||
err := s.configApplyer.ApplyConfig()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.interfaceEnabler.EnableInterface(meshId)
|
||||
}
|
||||
|
||||
// GetPublicKey: Gets the public key of the WireGuard mesh
|
||||
func (s *MeshManager) GetPublicKey(meshId string) (*wgtypes.Key, error) {
|
||||
mesh, ok := s.Meshes[meshId]
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("mesh does not exist")
|
||||
}
|
||||
|
||||
dev, err := mesh.GetDevice()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &dev.PublicKey, nil
|
||||
}
|
||||
|
||||
func (s *MeshManager) GetSelf(meshId string) (MeshNode, error) {
|
||||
meshInstance, ok := s.Meshes[meshId]
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
|
||||
}
|
||||
|
||||
snapshot, err := meshInstance.GetMesh()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
|
||||
|
||||
if !ok {
|
||||
return nil, errors.New("the node doesn't exist in the mesh")
|
||||
}
|
||||
|
||||
return node, nil
|
||||
}
|
||||
|
||||
// UpdateTimeStamp updates the timestamp of this node in all meshes
|
||||
func (s *MeshManager) UpdateTimeStamp() error {
|
||||
for _, mesh := range s.Meshes {
|
||||
err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Creates a new instance of a mesh manager with the given parameters
|
||||
func NewMeshManager(conf conf.WgMeshConfiguration, client *wgctrl.Client, meshProvider MeshProviderFactory) *MeshManager {
|
||||
hostParams := HostParameters{}
|
||||
|
||||
switch conf.PublicEndpoint {
|
||||
case "":
|
||||
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", lib.GetOutboundIP().String(), conf.GrpcPort)
|
||||
default:
|
||||
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", conf.PublicEndpoint, conf.GrpcPort)
|
||||
}
|
||||
|
||||
logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint)
|
||||
|
||||
m := &MeshManager{
|
||||
Meshes: make(map[string]MeshProvider),
|
||||
HostParameters: &hostParams,
|
||||
meshProviderFactory: meshProvider,
|
||||
Client: client,
|
||||
conf: &conf,
|
||||
}
|
||||
m.configApplyer = NewWgMeshConfigApplyer(m)
|
||||
m.RouteManager = NewRouteManager(m)
|
||||
m.interfaceEnabler = NewWgMeshInterface(m)
|
||||
return m
|
||||
}
|
@ -1,78 +0,0 @@
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
|
||||
"github.com/tim-beatham/wgmesh/pkg/ip"
|
||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||
"github.com/tim-beatham/wgmesh/pkg/route"
|
||||
)
|
||||
|
||||
type RouteManager interface {
|
||||
UpdateRoutes() error
|
||||
ApplyWg(mesh *crdt.CrdtNodeManager) error
|
||||
}
|
||||
|
||||
type RouteManagerImpl struct {
|
||||
meshManager *MeshManger
|
||||
routeInstaller route.RouteInstaller
|
||||
}
|
||||
|
||||
func (r *RouteManagerImpl) UpdateRoutes() error {
|
||||
meshes := r.meshManager.Meshes
|
||||
ulaBuilder := new(ip.ULABuilder)
|
||||
|
||||
for _, mesh1 := range meshes {
|
||||
for _, mesh2 := range meshes {
|
||||
if mesh1 == mesh2 {
|
||||
continue
|
||||
}
|
||||
|
||||
ipNet, err := ulaBuilder.GetIPNet(mesh2.MeshId)
|
||||
|
||||
if err != nil {
|
||||
logging.Log.WriteErrorf(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
mesh1.AddRoutes(ipNet.String())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteManagerImpl) ApplyWg(mesh *crdt.CrdtNodeManager) error {
|
||||
snapshot, err := mesh.GetCrdt()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, node := range snapshot.Nodes {
|
||||
if node.HostEndpoint == r.meshManager.HostEndpoint {
|
||||
continue
|
||||
}
|
||||
|
||||
for route, _ := range node.Routes {
|
||||
_, netIP, err := net.ParseCIDR(route)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = r.routeInstaller.InstallRoutes(mesh.IfName, netIP)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewRouteManager(m *MeshManger) RouteManager {
|
||||
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()}
|
||||
}
|
73
pkg/mesh/routemanager.go
Normal file
73
pkg/mesh/routemanager.go
Normal file
@ -0,0 +1,73 @@
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"github.com/tim-beatham/wgmesh/pkg/route"
|
||||
)
|
||||
|
||||
type RouteManager interface {
|
||||
UpdateRoutes() error
|
||||
ApplyWg() error
|
||||
}
|
||||
|
||||
type RouteManagerImpl struct {
|
||||
meshManager *MeshManager
|
||||
routeInstaller route.RouteInstaller
|
||||
}
|
||||
|
||||
func (r *RouteManagerImpl) UpdateRoutes() error {
|
||||
// // meshes := r.meshManager.Meshes
|
||||
// // ulaBuilder := new(ip.ULABuilder)
|
||||
|
||||
// for _, mesh1 := range meshes {
|
||||
// for _, mesh2 := range meshes {
|
||||
// if mesh1 == mesh2 {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// ipNet, err := ulaBuilder.GetIPNet(mesh2.MeshId)
|
||||
|
||||
// if err != nil {
|
||||
// logging.Log.WriteErrorf(err.Error())
|
||||
// return err
|
||||
// }
|
||||
|
||||
// mesh1.AddRoutes(ipNet.String())
|
||||
// }
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RouteManagerImpl) ApplyWg() error {
|
||||
// snapshot, err := mesh.GetMesh()
|
||||
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// for _, node := range snapshot.Nodes {
|
||||
// if node.HostEndpoint == r.meshManager.HostEndpoint {
|
||||
// continue
|
||||
// }
|
||||
|
||||
// for route, _ := range node.Routes {
|
||||
// _, netIP, err := net.ParseCIDR(route)
|
||||
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// err = r.routeInstaller.InstallRoutes(mesh.IfName, netIP)
|
||||
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewRouteManager(m *MeshManager) RouteManager {
|
||||
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()}
|
||||
}
|
84
pkg/mesh/types.go
Normal file
84
pkg/mesh/types.go
Normal file
@ -0,0 +1,84 @@
|
||||
// mesh provides implementation agnostic logic for managing
|
||||
// the mesh
|
||||
package mesh
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
// MeshNode represents an implementation of a node in a mesh
|
||||
type MeshNode interface {
|
||||
// GetHostEndpoint: gets the gRPC endpoint of the node
|
||||
GetHostEndpoint() string
|
||||
// GetPublicKey: gets the public key of the node
|
||||
GetPublicKey() (wgtypes.Key, error)
|
||||
// GetWgEndpoint(): get IP and port of the wireguard endpoint
|
||||
GetWgEndpoint() string
|
||||
// GetWgHost: get the IP address of the WireGuard node
|
||||
GetWgHost() *net.IPNet
|
||||
// GetTimestamp: get the UNIX time stamp of the ndoe
|
||||
GetTimeStamp() int64
|
||||
// GetRoutes: returns the routes that the nodes provides
|
||||
GetRoutes() []string
|
||||
}
|
||||
|
||||
type MeshSnapshot interface {
|
||||
// GetNodes() returns the nodes in the mesh
|
||||
GetNodes() map[string]MeshNode
|
||||
}
|
||||
|
||||
// MeshSyncer syncs two meshes
|
||||
type MeshSyncer interface {
|
||||
GenerateMessage() ([]byte, bool)
|
||||
RecvMessage(mesg []byte) error
|
||||
Complete()
|
||||
}
|
||||
|
||||
// Mesh: Represents an implementation of a mesh
|
||||
type MeshProvider interface {
|
||||
// AddNode() adds a node to the mesh
|
||||
AddNode(node MeshNode)
|
||||
// GetMesh() returns a snapshot of the mesh provided by the mesh provider
|
||||
GetMesh() (MeshSnapshot, error)
|
||||
// GetMeshId() returns the ID of the mesh network
|
||||
GetMeshId() string
|
||||
// Save() saves the mesh network
|
||||
Save() []byte
|
||||
// Load() loads a mesh network
|
||||
Load([]byte) error
|
||||
// GetDevice() get the device corresponding with the mesh
|
||||
GetDevice() (*wgtypes.Device, error)
|
||||
// HasChanges returns true if we have changes since last time we synced
|
||||
HasChanges() bool
|
||||
// Record that we have changges and save the corresponding changes
|
||||
SaveChanges()
|
||||
// UpdateTimeStamp: update the timestamp of the given node
|
||||
UpdateTimeStamp(nodeId string) error
|
||||
// AddRoutes: adds routes to the given node
|
||||
AddRoutes(nodeId string, route ...string) error
|
||||
GetSyncer() MeshSyncer
|
||||
}
|
||||
|
||||
// HostParameters contains the IDs of a node
|
||||
type HostParameters struct {
|
||||
HostEndpoint string
|
||||
// TODO: Contain the WireGuard identifier in this
|
||||
}
|
||||
|
||||
// MeshProviderFactoryParams parameters required to build a mesh provider
|
||||
type MeshProviderFactoryParams struct {
|
||||
DevName string
|
||||
MeshId string
|
||||
Port int
|
||||
Conf *conf.WgMeshConfiguration
|
||||
Client *wgctrl.Client
|
||||
}
|
||||
|
||||
// MeshProviderFactory creates an instance of a mesh provider
|
||||
type MeshProviderFactory interface {
|
||||
CreateMesh(params *MeshProviderFactoryParams) (MeshProvider, error)
|
||||
}
|
@ -18,12 +18,12 @@ import (
|
||||
"github.com/tim-beatham/wgmesh/pkg/wg"
|
||||
)
|
||||
|
||||
type RobinIpc struct {
|
||||
type IpcHandler struct {
|
||||
Server *ctrlserver.MeshCtrlServer
|
||||
ipAllocator ip.IPAllocator
|
||||
}
|
||||
|
||||
func (n *RobinIpc) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
|
||||
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
|
||||
wg.CreateInterface(args.IfName)
|
||||
|
||||
meshId, err := n.Server.MeshManager.CreateMesh(args.IfName, args.WgPort)
|
||||
@ -54,7 +54,7 @@ func (n *RobinIpc) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
|
||||
Routes: map[string]interface{}{},
|
||||
}
|
||||
|
||||
n.Server.MeshManager.AddMeshNode(meshId, meshNode)
|
||||
n.Server.MeshManager.AddMeshNode(meshId, &meshNode)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@ -64,12 +64,12 @@ func (n *RobinIpc) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *RobinIpc) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
|
||||
func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
|
||||
meshNames := make([]string, len(n.Server.MeshManager.Meshes))
|
||||
|
||||
i := 0
|
||||
for _, mesh := range n.Server.MeshManager.Meshes {
|
||||
meshNames[i] = mesh.MeshId
|
||||
for meshId, _ := range n.Server.MeshManager.Meshes {
|
||||
meshNames[i] = meshId
|
||||
i++
|
||||
}
|
||||
|
||||
@ -77,7 +77,7 @@ func (n *RobinIpc) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
|
||||
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
|
||||
peerConnection, err := n.Server.ConnectionManager.GetConnection(args.IpAdress)
|
||||
|
||||
client, err := peerConnection.GetClient()
|
||||
@ -130,33 +130,39 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
|
||||
WgHost: ipAddr.String() + "/128",
|
||||
Routes: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
n.Server.MeshManager.AddMeshNode(args.MeshId, node)
|
||||
n.Server.MeshManager.AddMeshNode(args.MeshId, &node)
|
||||
*reply = strconv.FormatBool(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *RobinIpc) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
|
||||
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
|
||||
mesh := n.Server.MeshManager.GetMesh(meshId)
|
||||
meshSnapshot, err := mesh.GetCrdt()
|
||||
meshSnapshot, err := mesh.GetMesh()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if mesh != nil {
|
||||
nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.Nodes))
|
||||
if mesh == nil {
|
||||
return errors.New("mesh does not exist")
|
||||
}
|
||||
nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.GetNodes()))
|
||||
|
||||
i := 0
|
||||
for _, node := range meshSnapshot.Nodes {
|
||||
for _, node := range meshSnapshot.GetNodes() {
|
||||
pubKey, _ := node.GetPublicKey()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
node := ctrlserver.MeshNode{
|
||||
HostEndpoint: node.HostEndpoint,
|
||||
WgEndpoint: node.WgEndpoint,
|
||||
PublicKey: node.PublicKey,
|
||||
WgHost: node.WgHost,
|
||||
Failed: mesh.HasFailed(node.HostEndpoint),
|
||||
Timestamp: node.Timestamp,
|
||||
Routes: lib.MapKeys(node.Routes),
|
||||
HostEndpoint: node.GetHostEndpoint(),
|
||||
WgEndpoint: node.GetWgEndpoint(),
|
||||
PublicKey: pubKey.String(),
|
||||
WgHost: node.GetWgHost().String(),
|
||||
Timestamp: node.GetTimeStamp(),
|
||||
Routes: node.GetRoutes(),
|
||||
}
|
||||
|
||||
nodes[i] = node
|
||||
@ -164,13 +170,11 @@ func (n *RobinIpc) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
|
||||
}
|
||||
|
||||
*reply = ipc.GetMeshReply{Nodes: nodes}
|
||||
} else {
|
||||
return errors.New("mesh does not exist")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *RobinIpc) EnableInterface(meshId string, reply *string) error {
|
||||
func (n *IpcHandler) EnableInterface(meshId string, reply *string) error {
|
||||
err := n.Server.MeshManager.EnableInterface(meshId)
|
||||
|
||||
if err != nil {
|
||||
@ -182,7 +186,7 @@ func (n *RobinIpc) EnableInterface(meshId string, reply *string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *RobinIpc) GetDOT(meshId string, reply *string) error {
|
||||
func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
|
||||
g := mesh.NewMeshDotConverter(n.Server.MeshManager)
|
||||
|
||||
result, err := g.Generate(meshId)
|
||||
@ -200,8 +204,8 @@ type RobinIpcParams struct {
|
||||
Allocator ip.IPAllocator
|
||||
}
|
||||
|
||||
func NewRobinIpc(ipcParams RobinIpcParams) RobinIpc {
|
||||
return RobinIpc{
|
||||
func NewRobinIpc(ipcParams RobinIpcParams) IpcHandler {
|
||||
return IpcHandler{
|
||||
Server: ipcParams.CtrlServer,
|
||||
ipAllocator: ipcParams.Allocator,
|
||||
}
|
@ -8,7 +8,7 @@ import (
|
||||
"github.com/tim-beatham/wgmesh/pkg/rpc"
|
||||
)
|
||||
|
||||
type RobinRpc struct {
|
||||
type WgRpc struct {
|
||||
rpc.UnimplementedMeshCtrlServerServer
|
||||
Server *ctrlserver.MeshCtrlServer
|
||||
}
|
||||
@ -36,7 +36,7 @@ func nodesToRpcNodes(nodes map[string]ctrlserver.MeshNode) []*rpc.MeshNode {
|
||||
return meshNodes
|
||||
}
|
||||
|
||||
func (m *RobinRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) {
|
||||
func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) {
|
||||
mesh := m.Server.MeshManager.GetMesh(request.MeshId)
|
||||
|
||||
if mesh == nil {
|
||||
@ -52,6 +52,6 @@ func (m *RobinRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*r
|
||||
return &reply, nil
|
||||
}
|
||||
|
||||
func (m *RobinRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
|
||||
func (m *WgRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
|
||||
return &rpc.JoinMeshReply{Success: true}, nil
|
||||
}
|
@ -1,9 +0,0 @@
|
||||
package rpc
|
||||
|
||||
import grpc "google.golang.org/grpc"
|
||||
|
||||
func NewRpcServer(rpcServer *grpc.Server, server MeshCtrlServerServer, auth AuthenticationServer) *grpc.Server {
|
||||
RegisterMeshCtrlServerServer(rpcServer, server)
|
||||
RegisterAuthenticationServer(rpcServer, auth)
|
||||
return rpcServer
|
||||
}
|
@ -18,13 +18,12 @@ type Syncer interface {
|
||||
}
|
||||
|
||||
type SyncerImpl struct {
|
||||
manager *mesh.MeshManger
|
||||
manager *mesh.MeshManager
|
||||
requester SyncRequester
|
||||
authenticatedNodes []crdt.MeshNodeCrdt
|
||||
}
|
||||
|
||||
const subSetLength = 5
|
||||
const maxAuthentications = 30
|
||||
const subSetLength = 3
|
||||
|
||||
// Sync: Sync random nodes
|
||||
func (s *SyncerImpl) Sync(meshId string) error {
|
||||
@ -39,27 +38,23 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
||||
return errors.New("the provided mesh does not exist")
|
||||
}
|
||||
|
||||
snapshot, err := mesh.GetCrdt()
|
||||
snapshot, err := mesh.GetMesh()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(snapshot.Nodes) <= 1 {
|
||||
nodes := snapshot.GetNodes()
|
||||
|
||||
if len(nodes) <= 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
excludedNodes := map[string]struct{}{
|
||||
s.manager.HostEndpoint: {},
|
||||
s.manager.HostParameters.HostEndpoint: {},
|
||||
}
|
||||
|
||||
for _, node := range snapshot.Nodes {
|
||||
if mesh.HasFailed(node.HostEndpoint) {
|
||||
excludedNodes[node.HostEndpoint] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes)
|
||||
meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes)
|
||||
randomSubset := lib.RandomSubsetOfLength(meshNodes, subSetLength)
|
||||
|
||||
before := time.Now()
|
||||
@ -71,7 +66,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
||||
|
||||
syncMeshFunc := func() error {
|
||||
defer waitGroup.Done()
|
||||
err := s.requester.SyncMesh(meshId, n.HostEndpoint)
|
||||
err := s.requester.SyncMesh(meshId, n.GetHostEndpoint())
|
||||
return err
|
||||
}
|
||||
|
||||
@ -86,8 +81,8 @@ func (s *SyncerImpl) Sync(meshId string) error {
|
||||
|
||||
// SyncMeshes: Sync all meshes
|
||||
func (s *SyncerImpl) SyncMeshes() error {
|
||||
for _, m := range s.manager.Meshes {
|
||||
err := s.Sync(m.MeshId)
|
||||
for meshId, _ := range s.manager.Meshes {
|
||||
err := s.Sync(meshId)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@ -97,6 +92,6 @@ func (s *SyncerImpl) SyncMeshes() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewSyncer(m *mesh.MeshManger, r SyncRequester) Syncer {
|
||||
func NewSyncer(m *mesh.MeshManager, r SyncRequester) Syncer {
|
||||
return &SyncerImpl{manager: m, requester: r}
|
||||
}
|
||||
|
@ -7,12 +7,14 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// SyncErrorHandler: Handles errors when attempting to sync
|
||||
type SyncErrorHandler interface {
|
||||
Handle(meshId string, endpoint string, err error) bool
|
||||
}
|
||||
|
||||
// SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler
|
||||
type SyncErrorHandlerImpl struct {
|
||||
meshManager *mesh.MeshManger
|
||||
meshManager *mesh.MeshManager
|
||||
}
|
||||
|
||||
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
|
||||
@ -22,12 +24,6 @@ func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint stri
|
||||
return false
|
||||
}
|
||||
|
||||
err := mesh.IncrementFailedCount(endpoint)
|
||||
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@ -44,6 +40,6 @@ func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error)
|
||||
return false
|
||||
}
|
||||
|
||||
func NewSyncErrorHandler(m *mesh.MeshManger) SyncErrorHandler {
|
||||
func NewSyncErrorHandler(m *mesh.MeshManager) SyncErrorHandler {
|
||||
return &SyncErrorHandlerImpl{meshManager: m}
|
||||
}
|
||||
|
@ -6,9 +6,9 @@ import (
|
||||
"io"
|
||||
"time"
|
||||
|
||||
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
|
||||
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
|
||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||
"github.com/tim-beatham/wgmesh/pkg/mesh"
|
||||
"github.com/tim-beatham/wgmesh/pkg/rpc"
|
||||
)
|
||||
|
||||
@ -94,11 +94,10 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
||||
}
|
||||
|
||||
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)
|
||||
mesh.DecrementFailedCount(endpoint)
|
||||
return nil
|
||||
}
|
||||
|
||||
func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncServiceClient) error {
|
||||
func syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
|
||||
stream, err := client.SyncMesh(ctx)
|
||||
|
||||
syncer := mesh.GetSyncer()
|
||||
@ -110,7 +109,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe
|
||||
for {
|
||||
msg, moreMessages := syncer.GenerateMessage()
|
||||
|
||||
err := stream.Send(&rpc.SyncMeshRequest{MeshId: mesh.MeshId, Changes: msg})
|
||||
err := stream.Send(&rpc.SyncMeshRequest{MeshId: mesh.GetMeshId(), Changes: msg})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -6,8 +6,8 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
|
||||
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
|
||||
"github.com/tim-beatham/wgmesh/pkg/mesh"
|
||||
"github.com/tim-beatham/wgmesh/pkg/rpc"
|
||||
)
|
||||
|
||||
@ -37,7 +37,7 @@ func (s *SyncServiceImpl) GetConf(context context.Context, request *rpc.GetConfR
|
||||
// SyncMesh: syncs the two streams changes
|
||||
func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error {
|
||||
var meshId = ""
|
||||
var syncer *crdt.AutomergeSync = nil
|
||||
var syncer mesh.MeshSyncer = nil
|
||||
|
||||
for {
|
||||
in, err := stream.Recv()
|
||||
|
@ -14,7 +14,7 @@ type TimestampScheduler interface {
|
||||
}
|
||||
|
||||
type TimeStampSchedulerImpl struct {
|
||||
meshManager *mesh.MeshManger
|
||||
meshManager *mesh.MeshManager
|
||||
updateRate int
|
||||
quit chan struct{}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user