Interfacing out components for unit testing

This commit is contained in:
Tim Beatham 2023-10-26 16:53:12 +01:00
parent f1cfd52a91
commit 4c6bbcffcd
28 changed files with 714 additions and 494 deletions

View File

@ -22,7 +22,7 @@ func createMesh(client *ipcRpc.Client, ifName string, wgPort int) string {
WgPort: wgPort,
}
err := client.Call("RobinIpc.CreateMesh", &newMeshParams, &reply)
err := client.Call("IpcHandler.CreateMesh", &newMeshParams, &reply)
if err != nil {
return err.Error()
@ -34,7 +34,7 @@ func createMesh(client *ipcRpc.Client, ifName string, wgPort int) string {
func listMeshes(client *ipcRpc.Client) {
reply := new(ipc.ListMeshReply)
err := client.Call("RobinIpc.ListMeshes", "", &reply)
err := client.Call("IpcHandler.ListMeshes", "", &reply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
@ -56,7 +56,7 @@ func joinMesh(client *ipcRpc.Client, meshId string, ipAddress string, ifName str
Port: wgPort,
}
err := client.Call("RobinIpc.JoinMesh", &args, &reply)
err := client.Call("IpcHandler.JoinMesh", &args, &reply)
if err != nil {
return err.Error()
@ -68,7 +68,7 @@ func joinMesh(client *ipcRpc.Client, meshId string, ipAddress string, ifName str
func getMesh(client *ipcRpc.Client, meshId string) {
reply := new(ipc.GetMeshReply)
err := client.Call("RobinIpc.GetMesh", &meshId, &reply)
err := client.Call("IpcHandler.GetMesh", &meshId, &reply)
if err != nil {
log.Panic(err.Error())
@ -92,7 +92,7 @@ func getMesh(client *ipcRpc.Client, meshId string) {
func enableInterface(client *ipcRpc.Client, meshId string) {
var reply string
err := client.Call("RobinIpc.EnableInterface", &meshId, &reply)
err := client.Call("IpcHandler.EnableInterface", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())
@ -105,7 +105,7 @@ func enableInterface(client *ipcRpc.Client, meshId string) {
func getGraph(client *ipcRpc.Client, meshId string) {
var reply string
err := client.Call("RobinIpc.GetDOT", &meshId, &reply)
err := client.Call("IpcHandler.GetDOT", &meshId, &reply)
if err != nil {
fmt.Println(err.Error())

View File

@ -28,8 +28,8 @@ func main() {
return
}
var robinRpc robin.RobinRpc
var robinIpc robin.RobinIpc
var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler
var authProvider middleware.AuthRpcProvider
var syncProvider sync.SyncServiceImpl

View File

@ -2,21 +2,22 @@ package crdt
import (
"errors"
"fmt"
"net"
"strings"
"time"
"github.com/automerge/automerge-go"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// CrdtNodeManager manages nodes in the crdt mesh
type CrdtNodeManager struct {
// CrdtMeshManager manages nodes in the crdt mesh
type CrdtMeshManager struct {
MeshId string
IfName string
NodeId string
@ -26,57 +27,63 @@ type CrdtNodeManager struct {
conf *conf.WgMeshConfiguration
}
const maxFails = 5
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
crdt, ok := node.(*MeshNodeCrdt)
if !ok {
panic("node must be of type *MeshNodeCrdt")
}
func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) {
crdt.FailedMap = automerge.NewMap()
crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
nodeVal, _ := c.doc.Path("nodes").Map().Get(crdt.HostEndpoint)
nodeVal.Map().Set("routes", automerge.NewMap())
}
func (c *CrdtNodeManager) ApplyWg() error {
snapshot, err := c.GetCrdt()
func (c *CrdtMeshManager) ApplyWg() error {
// snapshot, err := c.GetMesh()
if err != nil {
return err
}
// if err != nil {
// return err
// }
c.updateWgConf(c.IfName, snapshot.Nodes, *c.Client)
// c.updateWgConf(c.IfName, snapshot.GetNodes(), *c.Client)
return nil
}
// GetCrdt(): Converts the document into a struct
func (c *CrdtNodeManager) GetCrdt() (*MeshCrdt, error) {
// GetMesh(): Converts the document into a struct
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return automerge.As[*MeshCrdt](c.doc.Root())
}
// GetMeshId returns the meshid of the mesh
func (c *CrdtMeshManager) GetMeshId() string {
return c.MeshId
}
// Save: Save an entire mesh network
func (c *CrdtMeshManager) Save() []byte {
return c.doc.Save()
}
// Load: Load an entire mesh network
func (c *CrdtNodeManager) Load(bytes []byte) error {
func (c *CrdtMeshManager) Load(bytes []byte) error {
doc, err := automerge.Load(bytes)
if err != nil {
return err
}
c.doc = doc
return nil
}
// Save: Save an entire mesh network
func (c *CrdtNodeManager) Save() []byte {
return c.doc.Save()
}
// NewCrdtNodeManager: Create a new crdt node manager
func NewCrdtNodeManager(meshId, hostId, devName string, port int, conf conf.WgMeshConfiguration, client *wgctrl.Client) (*CrdtNodeManager, error) {
var manager CrdtNodeManager
func NewCrdtNodeManager(meshId, devName string, port int, conf conf.WgMeshConfiguration, client *wgctrl.Client) (*CrdtMeshManager, error) {
var manager CrdtMeshManager
manager.MeshId = meshId
manager.doc = automerge.New()
manager.IfName = devName
manager.Client = client
manager.NodeId = hostId
manager.conf = &conf
err := wg.CreateWgInterface(client, devName, port)
@ -88,7 +95,7 @@ func NewCrdtNodeManager(meshId, hostId, devName string, port int, conf conf.WgMe
return &manager, nil
}
func (m *CrdtNodeManager) convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfig, error) {
func (m *CrdtMeshManager) convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfig, error) {
peerEndpoint, err := net.ResolveUDPAddr("udp", node.WgEndpoint)
if err != nil {
@ -125,45 +132,7 @@ func (m *CrdtNodeManager) convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfi
return &peerConfig, nil
}
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
return strings.Compare(m1.PublicKey, m2.PublicKey)
}
func (c *CrdtNodeManager) changeFailedCount(meshId, endpoint string, incAmount int64) error {
node, err := c.doc.Path("nodes").Map().Get(endpoint)
if err != nil {
return err
}
counterMap, err := node.Map().Get("failedMap")
if counterMap.Kind() == automerge.KindVoid {
return errors.New("Something went wrong map does not exist")
}
counter, _ := counterMap.Map().Get(c.NodeId)
if counter.Kind() == automerge.KindVoid {
err = counterMap.Map().Set(c.NodeId, incAmount)
} else {
if counter.Int64()+incAmount < 0 {
return nil
}
err = counterMap.Map().Set(c.NodeId, counter.Int64()+1)
}
return err
}
// Increment failed count increments the number of times we have attempted
// to contact the node and it's failed
func (c *CrdtNodeManager) IncrementFailedCount(endpoint string) error {
return c.changeFailedCount(c.MeshId, endpoint, +1)
}
func (c *CrdtNodeManager) removeNode(endpoint string) error {
func (c *CrdtMeshManager) removeNode(endpoint string) error {
err := c.doc.Path("nodes").Map().Delete(endpoint)
if err != nil {
@ -173,14 +142,8 @@ func (c *CrdtNodeManager) removeNode(endpoint string) error {
return nil
}
// Decrement failed count decrements the number of times we have attempted to
// contact the node and it's failed
func (c *CrdtNodeManager) DecrementFailedCount(endpoint string) error {
return c.changeFailedCount(c.MeshId, endpoint, -1)
}
// GetNode: returns a mesh node crdt.
func (m *CrdtNodeManager) GetNode(endpoint string) (*MeshNodeCrdt, error) {
func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint)
if err != nil {
@ -196,11 +159,11 @@ func (m *CrdtNodeManager) GetNode(endpoint string) (*MeshNodeCrdt, error) {
return meshNode, nil
}
func (m *CrdtNodeManager) Length() int {
func (m *CrdtMeshManager) Length() int {
return m.doc.Path("nodes").Map().Len()
}
func (m *CrdtNodeManager) GetDevice() (*wgtypes.Device, error) {
func (m *CrdtMeshManager) GetDevice() (*wgtypes.Device, error) {
dev, err := m.Client.Device(m.IfName)
if err != nil {
@ -211,7 +174,7 @@ func (m *CrdtNodeManager) GetDevice() (*wgtypes.Device, error) {
}
// HasChanges returns true if we have changes since the last time we synced
func (m *CrdtNodeManager) HasChanges() bool {
func (m *CrdtMeshManager) HasChanges() bool {
changes, err := m.doc.Changes(m.LastHash)
logging.Log.WriteInfof("Changes %s", m.LastHash.String())
@ -224,34 +187,11 @@ func (m *CrdtNodeManager) HasChanges() bool {
return len(changes) > 0
}
func (m *CrdtNodeManager) HasFailed(endpoint string) bool {
node, err := m.GetNode(endpoint)
if err != nil {
logging.Log.WriteErrorf("Cannot get node node: %s\n", endpoint)
return true
func (m *CrdtMeshManager) HasFailed(endpoint string) bool {
return false
}
values, err := node.FailedMap.Values()
if err != nil {
return true
}
countFailed := 0
for _, value := range values {
count := value.Int64()
if count >= 1 {
countFailed++
}
}
return countFailed >= 4
}
func (m *CrdtNodeManager) SaveChanges() {
func (m *CrdtMeshManager) SaveChanges() {
hashes := m.doc.Heads()
hash := hashes[len(hashes)-1]
@ -259,13 +199,17 @@ func (m *CrdtNodeManager) SaveChanges() {
m.LastHash = hash
}
func (m *CrdtNodeManager) UpdateTimeStamp() error {
node, err := m.doc.Path("nodes").Map().Get(m.NodeId)
func (m *CrdtMeshManager) UpdateTimeStamp(nodeId string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if node.Kind() != automerge.KindMap {
return errors.New("node is not a map")
}
err = node.Map().Set("timestamp", time.Now().Unix())
if err == nil {
@ -275,7 +219,32 @@ func (m *CrdtNodeManager) UpdateTimeStamp() error {
return err
}
func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
// AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return err
}
for _, route := range routes {
err = routeMap.Map().Set(route, struct{}{})
if err != nil {
return err
}
}
return nil
}
func (m *CrdtMeshManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
var count int = 0
@ -300,35 +269,58 @@ func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNode
return nil
}
// AddRoutes: adds routes to the specific nodeId
func (m *CrdtNodeManager) AddRoutes(routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(m.NodeId)
if err != nil {
return err
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return err
}
for _, route := range routes {
err = routeMap.Map().Set(route, struct{}{})
if err != nil {
return err
}
}
return nil
}
func (m *CrdtNodeManager) GetSyncer() *AutomergeSync {
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m)
}
func (n *MeshNodeCrdt) GetEscapedIP() string {
return fmt.Sprintf("\"%s\"", n.WgHost)
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
return strings.Compare(m1.PublicKey, m2.PublicKey)
}
func (m *MeshNodeCrdt) GetHostEndpoint() string {
return m.HostEndpoint
}
func (m *MeshNodeCrdt) GetPublicKey() (wgtypes.Key, error) {
return wgtypes.ParseKey(m.PublicKey)
}
func (m *MeshNodeCrdt) GetWgEndpoint() string {
return m.HostEndpoint
}
func (m *MeshNodeCrdt) GetWgHost() *net.IPNet {
_, ipnet, err := net.ParseCIDR(m.WgHost)
if err != nil {
logging.Log.WriteErrorf("Cannot parse WgHost %s", err.Error())
return nil
}
return ipnet
}
func (m *MeshNodeCrdt) GetTimeStamp() int64 {
return m.Timestamp
}
func (m *MeshNodeCrdt) GetRoutes() []string {
return lib.MapKeys(m.Routes)
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode)
for _, node := range m.Nodes {
nodes[node.HostEndpoint] = &MeshNodeCrdt{
HostEndpoint: node.HostEndpoint,
WgEndpoint: node.WgEndpoint,
PublicKey: node.PublicKey,
WgHost: node.WgHost,
Timestamp: node.Timestamp,
Routes: node.Routes,
}
}
return nodes
}

View File

@ -7,7 +7,7 @@ import (
type AutomergeSync struct {
state *automerge.SyncState
manager *CrdtNodeManager
manager *CrdtMeshManager
}
func (a *AutomergeSync) GenerateMessage() ([]byte, bool) {
@ -35,7 +35,7 @@ func (a *AutomergeSync) Complete() {
a.manager.SaveChanges()
}
func NewAutomergeSync(manager *CrdtNodeManager) *AutomergeSync {
func NewAutomergeSync(manager *CrdtMeshManager) *AutomergeSync {
return &AutomergeSync{
state: automerge.NewSyncState(manager.doc),
manager: manager,

View File

@ -0,0 +1,10 @@
package crdt
import "github.com/tim-beatham/wgmesh/pkg/mesh"
type CrdtProviderFactory struct{}
func (f *CrdtProviderFactory) CreateMesh(params *mesh.MeshProviderFactoryParams) (mesh.MeshProvider, error) {
return NewCrdtNodeManager(params.MeshId, params.DevName, params.Port,
*params.Conf, params.Client)
}

View File

@ -1,7 +1,5 @@
package crdt
import "github.com/automerge/automerge-go"
// MeshNodeCrdt: Represents a CRDT for a mesh nodes
type MeshNodeCrdt struct {
HostEndpoint string `automerge:"hostEndpoint"`
@ -9,7 +7,6 @@ type MeshNodeCrdt struct {
PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"`
FailedMap *automerge.Map `automerge:"failedMap"`
Routes map[string]interface{} `automerge:"routes"`
}

View File

@ -13,7 +13,11 @@ type WgMeshConfiguration struct {
PrivateKeyPath string `yaml:"privateKeyPath"`
SkipCertVerification bool `yaml:"skipCertVerification"`
GrpcPort string `yaml:"gRPCPort"`
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
// PublicEndpoint is the IP in which this computer is publicly reachable.
// usecase is when the node is behind NAT.
PublicEndpoint string `yaml:"publicEndpoint"`
}
func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) {

View File

@ -1,6 +1,7 @@
package ctrlserver
import (
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/conn"
"github.com/tim-beatham/wgmesh/pkg/mesh"
@ -21,7 +22,8 @@ type NewCtrlServerParams struct {
// operation failed
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer := new(MeshCtrlServer)
ctrlServer.MeshManager = mesh.NewMeshManager(*params.Conf, params.Client)
factory := crdt.CrdtProviderFactory{}
ctrlServer.MeshManager = mesh.NewMeshManager(*params.Conf, params.Client, &factory)
ctrlServer.Conf = params.Conf
connManagerParams := conn.NewConnectionManageParams{

View File

@ -16,7 +16,6 @@ type MeshNode struct {
WgEndpoint string
PublicKey string
WgHost string
Failed bool
Timestamp int64
Routes []string
}
@ -32,7 +31,7 @@ type Mesh struct {
*/
type MeshCtrlServer struct {
Client *wgctrl.Client
MeshManager *mesh.MeshManger
MeshManager *mesh.MeshManager
ConnectionManager conn.ConnectionManager
ConnectionServer *conn.ConnectionServer
Conf *conf.WgMeshConfiguration

View File

@ -1,15 +0,0 @@
/*
* RPC component of the server
*/
package ctrlserver
import (
"github.com/tim-beatham/wgmesh/pkg/rpc"
"google.golang.org/grpc"
)
func NewRpcServer(server rpc.MeshCtrlServerServer) *grpc.Server {
grpc := grpc.NewServer()
rpc.RegisterMeshCtrlServerServer(grpc, server)
return grpc
}

View File

@ -1,5 +1,9 @@
package lib
import (
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
// MapToSlice converts a map to a slice in go
func MapValues[K comparable, V any](m map[K]V) []V {
return MapValuesWithExclude(m, map[K]struct{}{})
@ -19,6 +23,8 @@ func MapValuesWithExclude[K comparable, V any](m map[K]V, exclude map[K]struct{}
continue
}
logging.Log.WriteInfof("Key %s", k)
values[i] = v
i++
}

View File

@ -5,14 +5,13 @@ import (
"net"
)
// GetOutboundIP: gets the oubound IP of this packet
func GetOutboundIP() net.IP {
conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil {
log.Fatal(err)
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP
}

View File

@ -8,13 +8,14 @@ import (
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// MeshGraphConverter converts a mesh to a graph
type MeshGraphConverter interface {
// convert the mesh to textual form
Generate(meshId string) (string, error)
}
type MeshDOTConverter struct {
manager *MeshManger
manager *MeshManager
}
func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
@ -26,35 +27,33 @@ func (c *MeshDOTConverter) Generate(meshId string) (string, error) {
g := graph.NewGraph(meshId, graph.GRAPH)
snapshot, err := mesh.GetCrdt()
snapshot, err := mesh.GetMesh()
if err != nil {
return "", err
}
for _, node := range snapshot.Nodes {
g.AddNode(node.GetEscapedIP())
for _, node := range snapshot.GetNodes() {
g.AddNode(fmt.Sprintf("\"%s\"", node.GetWgHost().IP.String()))
}
nodes := lib.MapValues(snapshot.Nodes)
nodes := lib.MapValues(snapshot.GetNodes())
for i, node1 := range nodes[:len(nodes)-1] {
if mesh.HasFailed(node1.HostEndpoint) {
continue
}
for _, node2 := range nodes[i+1:] {
if node1.WgEndpoint == node2.WgEndpoint || mesh.HasFailed(node2.HostEndpoint) {
if node1.GetWgEndpoint() == node2.GetWgEndpoint() {
continue
}
g.AddEdge(fmt.Sprintf("%s to %s", node1.GetEscapedIP(), node2.GetEscapedIP()), node1.GetEscapedIP(), node2.GetEscapedIP())
node1Id := fmt.Sprintf("\"%s\"", node1.GetWgHost().IP.String())
node2Id := fmt.Sprintf("\"%s\"", node2.GetWgHost().IP.String())
g.AddEdge(fmt.Sprintf("%s to %s", node1Id, node2Id), node1Id, node2Id)
}
}
return g.GetDOT()
}
func NewMeshDotConverter(m *MeshManger) MeshGraphConverter {
func NewMeshDotConverter(m *MeshManager) MeshGraphConverter {
return &MeshDOTConverter{manager: m}
}

View File

@ -1,159 +0,0 @@
package mesh
import (
"errors"
"fmt"
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshManger struct {
Meshes map[string]*crdt.CrdtNodeManager
RouteManager RouteManager
Client *wgctrl.Client
HostEndpoint string
conf *conf.WgMeshConfiguration
}
func (m *MeshManger) MeshExists(meshId string) bool {
_, inMesh := m.Meshes[meshId]
return inMesh
}
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManger) CreateMesh(devName string, port int) (string, error) {
key, err := wgtypes.GenerateKey()
if err != nil {
return "", err
}
nodeManager, err := crdt.NewCrdtNodeManager(key.String(), m.HostEndpoint, devName, port, *m.conf, m.Client)
if err != nil {
return "", err
}
m.Meshes[key.String()] = nodeManager
return key.String(), err
}
// AddMesh: Add the mesh to the list of meshes
func (m *MeshManger) AddMesh(meshId string, devName string, port int, meshBytes []byte) error {
mesh, err := crdt.NewCrdtNodeManager(meshId, m.HostEndpoint, devName, port, *m.conf, m.Client)
if err != nil {
return err
}
err = mesh.Load(meshBytes)
if err != nil {
return err
}
m.Meshes[meshId] = mesh
return nil
}
// AddMeshNode: Add a mesh node
func (m *MeshManger) AddMeshNode(meshId string, node crdt.MeshNodeCrdt) {
m.Meshes[meshId].AddNode(node)
if m.conf.AdvertiseRoutes {
m.RouteManager.UpdateRoutes()
}
}
func (m *MeshManger) HasChanges(meshId string) bool {
return m.Meshes[meshId].HasChanges()
}
func (m *MeshManger) GetMesh(meshId string) *crdt.CrdtNodeManager {
theMesh, _ := m.Meshes[meshId]
return theMesh
}
// EnableInterface: Enables the given WireGuard interface.
func (s *MeshManger) EnableInterface(meshId string) error {
mesh, contains := s.Meshes[meshId]
if !contains {
return errors.New("Mesh does not exist")
}
crdt, err := mesh.GetCrdt()
if err != nil {
return err
}
node, contains := crdt.Nodes[s.HostEndpoint]
if !contains {
return errors.New("Node does not exist in the mesh")
}
err = mesh.ApplyWg()
if err != nil {
return err
}
err = wg.EnableInterface(mesh.IfName, node.WgHost)
if s.conf.AdvertiseRoutes {
s.RouteManager.ApplyWg(mesh)
}
return nil
}
// GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManger) GetPublicKey(meshId string) (*wgtypes.Key, error) {
mesh, ok := s.Meshes[meshId]
if !ok {
return nil, errors.New("mesh does not exist")
}
dev, err := mesh.GetDevice()
if err != nil {
return nil, err
}
return &dev.PublicKey, nil
}
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManger) UpdateTimeStamp() error {
for _, mesh := range s.Meshes {
err := mesh.UpdateTimeStamp()
if err != nil {
return err
}
}
return nil
}
func NewMeshManager(conf conf.WgMeshConfiguration, client *wgctrl.Client) *MeshManger {
ip := lib.GetOutboundIP()
m := &MeshManger{
Meshes: make(map[string]*crdt.CrdtNodeManager),
HostEndpoint: fmt.Sprintf("%s:%s", ip.String(), conf.GrpcPort),
Client: client,
conf: &conf,
}
m.RouteManager = NewRouteManager(m)
return m
}

100
pkg/mesh/meshconfig.go Normal file
View File

@ -0,0 +1,100 @@
package mesh
import (
"net"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// MeshConfigApplyer abstracts applying the mesh configuration
type MeshConfigApplyer interface {
ApplyConfig() error
}
// WgMeshConfigApplyer applies WireGuard configuration
type WgMeshConfigApplyer struct {
meshManager *MeshManager
}
func ConvertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
if err != nil {
return nil, err
}
pubKey, err := node.GetPublicKey()
if err != nil {
return nil, err
}
allowedips := make([]net.IPNet, 1)
allowedips[0] = *node.GetWgHost()
for _, route := range node.GetRoutes() {
_, ipnet, _ := net.ParseCIDR(route)
allowedips = append(allowedips, *ipnet)
}
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
Endpoint: endpoint,
AllowedIPs: allowedips,
}
return &peerConfig, nil
}
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
snap, err := mesh.GetMesh()
if err != nil {
return err
}
nodes := snap.GetNodes()
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
var count int = 0
for _, n := range nodes {
peer, err := ConvertMeshNode(n)
if err != nil {
return err
}
peerConfigs[count] = *peer
count++
}
cfg := wgtypes.Config{
Peers: peerConfigs,
ReplacePeers: true,
}
dev, err := mesh.GetDevice()
if err != nil {
return err
}
return m.meshManager.Client.ConfigureDevice(dev.Name, cfg)
}
func (m *WgMeshConfigApplyer) ApplyConfig() error {
for _, mesh := range m.meshManager.Meshes {
err := m.updateWgConf(mesh)
if err != nil {
return err
}
}
return nil
}
func NewWgMeshConfigApplyer(manager *MeshManager) MeshConfigApplyer {
return &WgMeshConfigApplyer{meshManager: manager}
}

43
pkg/mesh/meshinterface.go Normal file
View File

@ -0,0 +1,43 @@
package mesh
import (
"errors"
"github.com/tim-beatham/wgmesh/pkg/wg"
)
// MeshInterfaces manipulates interfaces to do with meshes
type MeshInterface interface {
EnableInterface(meshId string) error
}
type WgMeshInterface struct {
manager *MeshManager
}
// EnableInterface enables the interface at the given endpoint
func (m *WgMeshInterface) EnableInterface(meshId string) error {
mesh, ok := m.manager.Meshes[meshId]
if !ok {
return errors.New("the provided mesh does not exist")
}
dev, err := mesh.GetDevice()
if err != nil {
return err
}
self, err := m.manager.GetSelf(meshId)
if err != nil {
return err
}
return wg.EnableInterface(dev.Name, self.GetWgHost().String())
}
func NewWgMeshInterface(manager *MeshManager) MeshInterface {
return &WgMeshInterface{manager: manager}
}

179
pkg/mesh/meshmanager.go Normal file
View File

@ -0,0 +1,179 @@
package mesh
import (
"errors"
"fmt"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshManager struct {
Meshes map[string]MeshProvider
RouteManager RouteManager
Client *wgctrl.Client
// HostParameters contains information that uniquely locates
// the node in the mesh network.
HostParameters *HostParameters
conf *conf.WgMeshConfiguration
meshProviderFactory MeshProviderFactory
configApplyer MeshConfigApplyer
interfaceEnabler MeshInterface
}
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManager) CreateMesh(devName string, port int) (string, error) {
key, err := wgtypes.GenerateKey()
if err != nil {
return "", err
}
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: devName,
Port: port,
Conf: m.conf,
Client: m.Client,
MeshId: key.String(),
})
if err != nil {
return "", err
}
m.Meshes[key.String()] = nodeManager
return key.String(), err
}
// AddMesh: Add the mesh to the list of meshes
func (m *MeshManager) AddMesh(meshId string, devName string, port int, meshBytes []byte) error {
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: devName,
Port: port,
Conf: m.conf,
Client: m.Client,
MeshId: meshId,
})
if err != nil {
return err
}
err = meshProvider.Load(meshBytes)
if err != nil {
return err
}
m.Meshes[meshId] = meshProvider
return nil
}
// AddMeshNode: Add a mesh node
func (m *MeshManager) AddMeshNode(meshId string, node MeshNode) {
m.Meshes[meshId].AddNode(node)
}
// HasChanges returns true if the mesh has changes
func (m *MeshManager) HasChanges(meshId string) bool {
return m.Meshes[meshId].HasChanges()
}
// GetMesh returns the mesh with the given meshid
func (m *MeshManager) GetMesh(meshId string) MeshProvider {
theMesh, _ := m.Meshes[meshId]
return theMesh
}
// EnableInterface: Enables the given WireGuard interface.
func (s *MeshManager) EnableInterface(meshId string) error {
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
return s.interfaceEnabler.EnableInterface(meshId)
}
// GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManager) GetPublicKey(meshId string) (*wgtypes.Key, error) {
mesh, ok := s.Meshes[meshId]
if !ok {
return nil, errors.New("mesh does not exist")
}
dev, err := mesh.GetDevice()
if err != nil {
return nil, err
}
return &dev.PublicKey, nil
}
func (s *MeshManager) GetSelf(meshId string) (MeshNode, error) {
meshInstance, ok := s.Meshes[meshId]
if !ok {
return nil, errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
}
snapshot, err := meshInstance.GetMesh()
if err != nil {
return nil, err
}
node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if !ok {
return nil, errors.New("the node doesn't exist in the mesh")
}
return node, nil
}
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManager) UpdateTimeStamp() error {
for _, mesh := range s.Meshes {
err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if err != nil {
return err
}
}
return nil
}
// Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(conf conf.WgMeshConfiguration, client *wgctrl.Client, meshProvider MeshProviderFactory) *MeshManager {
hostParams := HostParameters{}
switch conf.PublicEndpoint {
case "":
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", lib.GetOutboundIP().String(), conf.GrpcPort)
default:
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", conf.PublicEndpoint, conf.GrpcPort)
}
logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint)
m := &MeshManager{
Meshes: make(map[string]MeshProvider),
HostParameters: &hostParams,
meshProviderFactory: meshProvider,
Client: client,
conf: &conf,
}
m.configApplyer = NewWgMeshConfigApplyer(m)
m.RouteManager = NewRouteManager(m)
m.interfaceEnabler = NewWgMeshInterface(m)
return m
}

View File

@ -1,78 +0,0 @@
package mesh
import (
"net"
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
"github.com/tim-beatham/wgmesh/pkg/ip"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
)
type RouteManager interface {
UpdateRoutes() error
ApplyWg(mesh *crdt.CrdtNodeManager) error
}
type RouteManagerImpl struct {
meshManager *MeshManger
routeInstaller route.RouteInstaller
}
func (r *RouteManagerImpl) UpdateRoutes() error {
meshes := r.meshManager.Meshes
ulaBuilder := new(ip.ULABuilder)
for _, mesh1 := range meshes {
for _, mesh2 := range meshes {
if mesh1 == mesh2 {
continue
}
ipNet, err := ulaBuilder.GetIPNet(mesh2.MeshId)
if err != nil {
logging.Log.WriteErrorf(err.Error())
return err
}
mesh1.AddRoutes(ipNet.String())
}
}
return nil
}
func (r *RouteManagerImpl) ApplyWg(mesh *crdt.CrdtNodeManager) error {
snapshot, err := mesh.GetCrdt()
if err != nil {
return err
}
for _, node := range snapshot.Nodes {
if node.HostEndpoint == r.meshManager.HostEndpoint {
continue
}
for route, _ := range node.Routes {
_, netIP, err := net.ParseCIDR(route)
if err != nil {
return err
}
err = r.routeInstaller.InstallRoutes(mesh.IfName, netIP)
if err != nil {
return err
}
}
}
return nil
}
func NewRouteManager(m *MeshManger) RouteManager {
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()}
}

73
pkg/mesh/routemanager.go Normal file
View File

@ -0,0 +1,73 @@
package mesh
import (
"github.com/tim-beatham/wgmesh/pkg/route"
)
type RouteManager interface {
UpdateRoutes() error
ApplyWg() error
}
type RouteManagerImpl struct {
meshManager *MeshManager
routeInstaller route.RouteInstaller
}
func (r *RouteManagerImpl) UpdateRoutes() error {
// // meshes := r.meshManager.Meshes
// // ulaBuilder := new(ip.ULABuilder)
// for _, mesh1 := range meshes {
// for _, mesh2 := range meshes {
// if mesh1 == mesh2 {
// continue
// }
// ipNet, err := ulaBuilder.GetIPNet(mesh2.MeshId)
// if err != nil {
// logging.Log.WriteErrorf(err.Error())
// return err
// }
// mesh1.AddRoutes(ipNet.String())
// }
// }
return nil
}
func (r *RouteManagerImpl) ApplyWg() error {
// snapshot, err := mesh.GetMesh()
// if err != nil {
// return err
// }
// for _, node := range snapshot.Nodes {
// if node.HostEndpoint == r.meshManager.HostEndpoint {
// continue
// }
// for route, _ := range node.Routes {
// _, netIP, err := net.ParseCIDR(route)
// if err != nil {
// return err
// }
// err = r.routeInstaller.InstallRoutes(mesh.IfName, netIP)
// if err != nil {
// return err
// }
// }
// }
return nil
}
func NewRouteManager(m *MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()}
}

84
pkg/mesh/types.go Normal file
View File

@ -0,0 +1,84 @@
// mesh provides implementation agnostic logic for managing
// the mesh
package mesh
import (
"net"
"github.com/tim-beatham/wgmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// MeshNode represents an implementation of a node in a mesh
type MeshNode interface {
// GetHostEndpoint: gets the gRPC endpoint of the node
GetHostEndpoint() string
// GetPublicKey: gets the public key of the node
GetPublicKey() (wgtypes.Key, error)
// GetWgEndpoint(): get IP and port of the wireguard endpoint
GetWgEndpoint() string
// GetWgHost: get the IP address of the WireGuard node
GetWgHost() *net.IPNet
// GetTimestamp: get the UNIX time stamp of the ndoe
GetTimeStamp() int64
// GetRoutes: returns the routes that the nodes provides
GetRoutes() []string
}
type MeshSnapshot interface {
// GetNodes() returns the nodes in the mesh
GetNodes() map[string]MeshNode
}
// MeshSyncer syncs two meshes
type MeshSyncer interface {
GenerateMessage() ([]byte, bool)
RecvMessage(mesg []byte) error
Complete()
}
// Mesh: Represents an implementation of a mesh
type MeshProvider interface {
// AddNode() adds a node to the mesh
AddNode(node MeshNode)
// GetMesh() returns a snapshot of the mesh provided by the mesh provider
GetMesh() (MeshSnapshot, error)
// GetMeshId() returns the ID of the mesh network
GetMeshId() string
// Save() saves the mesh network
Save() []byte
// Load() loads a mesh network
Load([]byte) error
// GetDevice() get the device corresponding with the mesh
GetDevice() (*wgtypes.Device, error)
// HasChanges returns true if we have changes since last time we synced
HasChanges() bool
// Record that we have changges and save the corresponding changes
SaveChanges()
// UpdateTimeStamp: update the timestamp of the given node
UpdateTimeStamp(nodeId string) error
// AddRoutes: adds routes to the given node
AddRoutes(nodeId string, route ...string) error
GetSyncer() MeshSyncer
}
// HostParameters contains the IDs of a node
type HostParameters struct {
HostEndpoint string
// TODO: Contain the WireGuard identifier in this
}
// MeshProviderFactoryParams parameters required to build a mesh provider
type MeshProviderFactoryParams struct {
DevName string
MeshId string
Port int
Conf *conf.WgMeshConfiguration
Client *wgctrl.Client
}
// MeshProviderFactory creates an instance of a mesh provider
type MeshProviderFactory interface {
CreateMesh(params *MeshProviderFactoryParams) (MeshProvider, error)
}

View File

@ -18,12 +18,12 @@ import (
"github.com/tim-beatham/wgmesh/pkg/wg"
)
type RobinIpc struct {
type IpcHandler struct {
Server *ctrlserver.MeshCtrlServer
ipAllocator ip.IPAllocator
}
func (n *RobinIpc) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
wg.CreateInterface(args.IfName)
meshId, err := n.Server.MeshManager.CreateMesh(args.IfName, args.WgPort)
@ -54,7 +54,7 @@ func (n *RobinIpc) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
Routes: map[string]interface{}{},
}
n.Server.MeshManager.AddMeshNode(meshId, meshNode)
n.Server.MeshManager.AddMeshNode(meshId, &meshNode)
if err != nil {
return err
@ -64,12 +64,12 @@ func (n *RobinIpc) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
return nil
}
func (n *RobinIpc) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
func (n *IpcHandler) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
meshNames := make([]string, len(n.Server.MeshManager.Meshes))
i := 0
for _, mesh := range n.Server.MeshManager.Meshes {
meshNames[i] = mesh.MeshId
for meshId, _ := range n.Server.MeshManager.Meshes {
meshNames[i] = meshId
i++
}
@ -77,7 +77,7 @@ func (n *RobinIpc) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
return nil
}
func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
peerConnection, err := n.Server.ConnectionManager.GetConnection(args.IpAdress)
client, err := peerConnection.GetClient()
@ -130,33 +130,39 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
WgHost: ipAddr.String() + "/128",
Routes: make(map[string]interface{}),
}
n.Server.MeshManager.AddMeshNode(args.MeshId, node)
n.Server.MeshManager.AddMeshNode(args.MeshId, &node)
*reply = strconv.FormatBool(true)
return nil
}
func (n *RobinIpc) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.MeshManager.GetMesh(meshId)
meshSnapshot, err := mesh.GetCrdt()
meshSnapshot, err := mesh.GetMesh()
if err != nil {
return err
}
if mesh != nil {
nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.Nodes))
if mesh == nil {
return errors.New("mesh does not exist")
}
nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.GetNodes()))
i := 0
for _, node := range meshSnapshot.Nodes {
for _, node := range meshSnapshot.GetNodes() {
pubKey, _ := node.GetPublicKey()
if err != nil {
return err
}
node := ctrlserver.MeshNode{
HostEndpoint: node.HostEndpoint,
WgEndpoint: node.WgEndpoint,
PublicKey: node.PublicKey,
WgHost: node.WgHost,
Failed: mesh.HasFailed(node.HostEndpoint),
Timestamp: node.Timestamp,
Routes: lib.MapKeys(node.Routes),
HostEndpoint: node.GetHostEndpoint(),
WgEndpoint: node.GetWgEndpoint(),
PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(),
}
nodes[i] = node
@ -164,13 +170,11 @@ func (n *RobinIpc) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
}
*reply = ipc.GetMeshReply{Nodes: nodes}
} else {
return errors.New("mesh does not exist")
}
return nil
}
func (n *RobinIpc) EnableInterface(meshId string, reply *string) error {
func (n *IpcHandler) EnableInterface(meshId string, reply *string) error {
err := n.Server.MeshManager.EnableInterface(meshId)
if err != nil {
@ -182,7 +186,7 @@ func (n *RobinIpc) EnableInterface(meshId string, reply *string) error {
return nil
}
func (n *RobinIpc) GetDOT(meshId string, reply *string) error {
func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.MeshManager)
result, err := g.Generate(meshId)
@ -200,8 +204,8 @@ type RobinIpcParams struct {
Allocator ip.IPAllocator
}
func NewRobinIpc(ipcParams RobinIpcParams) RobinIpc {
return RobinIpc{
func NewRobinIpc(ipcParams RobinIpcParams) IpcHandler {
return IpcHandler{
Server: ipcParams.CtrlServer,
ipAllocator: ipcParams.Allocator,
}

View File

@ -8,7 +8,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/rpc"
)
type RobinRpc struct {
type WgRpc struct {
rpc.UnimplementedMeshCtrlServerServer
Server *ctrlserver.MeshCtrlServer
}
@ -36,7 +36,7 @@ func nodesToRpcNodes(nodes map[string]ctrlserver.MeshNode) []*rpc.MeshNode {
return meshNodes
}
func (m *RobinRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) {
func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.GetMeshReply, error) {
mesh := m.Server.MeshManager.GetMesh(request.MeshId)
if mesh == nil {
@ -52,6 +52,6 @@ func (m *RobinRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*r
return &reply, nil
}
func (m *RobinRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
func (m *WgRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
return &rpc.JoinMeshReply{Success: true}, nil
}

View File

@ -1,9 +0,0 @@
package rpc
import grpc "google.golang.org/grpc"
func NewRpcServer(rpcServer *grpc.Server, server MeshCtrlServerServer, auth AuthenticationServer) *grpc.Server {
RegisterMeshCtrlServerServer(rpcServer, server)
RegisterAuthenticationServer(rpcServer, auth)
return rpcServer
}

View File

@ -18,13 +18,12 @@ type Syncer interface {
}
type SyncerImpl struct {
manager *mesh.MeshManger
manager *mesh.MeshManager
requester SyncRequester
authenticatedNodes []crdt.MeshNodeCrdt
}
const subSetLength = 5
const maxAuthentications = 30
const subSetLength = 3
// Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error {
@ -39,27 +38,23 @@ func (s *SyncerImpl) Sync(meshId string) error {
return errors.New("the provided mesh does not exist")
}
snapshot, err := mesh.GetCrdt()
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
if len(snapshot.Nodes) <= 1 {
nodes := snapshot.GetNodes()
if len(nodes) <= 1 {
return nil
}
excludedNodes := map[string]struct{}{
s.manager.HostEndpoint: {},
s.manager.HostParameters.HostEndpoint: {},
}
for _, node := range snapshot.Nodes {
if mesh.HasFailed(node.HostEndpoint) {
excludedNodes[node.HostEndpoint] = struct{}{}
}
}
meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes)
meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes)
randomSubset := lib.RandomSubsetOfLength(meshNodes, subSetLength)
before := time.Now()
@ -71,7 +66,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
syncMeshFunc := func() error {
defer waitGroup.Done()
err := s.requester.SyncMesh(meshId, n.HostEndpoint)
err := s.requester.SyncMesh(meshId, n.GetHostEndpoint())
return err
}
@ -86,8 +81,8 @@ func (s *SyncerImpl) Sync(meshId string) error {
// SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error {
for _, m := range s.manager.Meshes {
err := s.Sync(m.MeshId)
for meshId, _ := range s.manager.Meshes {
err := s.Sync(meshId)
if err != nil {
return err
@ -97,6 +92,6 @@ func (s *SyncerImpl) SyncMeshes() error {
return nil
}
func NewSyncer(m *mesh.MeshManger, r SyncRequester) Syncer {
func NewSyncer(m *mesh.MeshManager, r SyncRequester) Syncer {
return &SyncerImpl{manager: m, requester: r}
}

View File

@ -7,12 +7,14 @@ import (
"google.golang.org/grpc/status"
)
// SyncErrorHandler: Handles errors when attempting to sync
type SyncErrorHandler interface {
Handle(meshId string, endpoint string, err error) bool
}
// SyncErrorHandlerImpl Is an implementation of the SyncErrorHandler
type SyncErrorHandlerImpl struct {
meshManager *mesh.MeshManger
meshManager *mesh.MeshManager
}
func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint string) bool {
@ -22,12 +24,6 @@ func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint stri
return false
}
err := mesh.IncrementFailedCount(endpoint)
if err != nil {
return false
}
return true
}
@ -44,6 +40,6 @@ func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error)
return false
}
func NewSyncErrorHandler(m *mesh.MeshManger) SyncErrorHandler {
func NewSyncErrorHandler(m *mesh.MeshManager) SyncErrorHandler {
return &SyncErrorHandlerImpl{meshManager: m}
}

View File

@ -6,9 +6,9 @@ import (
"io"
"time"
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/rpc"
)
@ -94,11 +94,10 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
}
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)
mesh.DecrementFailedCount(endpoint)
return nil
}
func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncServiceClient) error {
func syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
stream, err := client.SyncMesh(ctx)
syncer := mesh.GetSyncer()
@ -110,7 +109,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe
for {
msg, moreMessages := syncer.GenerateMessage()
err := stream.Send(&rpc.SyncMeshRequest{MeshId: mesh.MeshId, Changes: msg})
err := stream.Send(&rpc.SyncMeshRequest{MeshId: mesh.GetMeshId(), Changes: msg})
if err != nil {
return err

View File

@ -6,8 +6,8 @@ import (
"errors"
"io"
crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/rpc"
)
@ -37,7 +37,7 @@ func (s *SyncServiceImpl) GetConf(context context.Context, request *rpc.GetConfR
// SyncMesh: syncs the two streams changes
func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error {
var meshId = ""
var syncer *crdt.AutomergeSync = nil
var syncer mesh.MeshSyncer = nil
for {
in, err := stream.Recv()

View File

@ -14,7 +14,7 @@ type TimestampScheduler interface {
}
type TimeStampSchedulerImpl struct {
meshManager *mesh.MeshManger
meshManager *mesh.MeshManager
updateRate int
quit chan struct{}
}