Ability to be in multiple meshes and refactored consensus

This commit is contained in:
Tim Beatham 2023-10-24 16:00:46 +01:00
parent 8e89281484
commit 180f5e226c
19 changed files with 259 additions and 186 deletions

View File

@ -5,7 +5,7 @@ import (
"log" "log"
ipcRpc "net/rpc" ipcRpc "net/rpc"
"os" "os"
"strconv" "time"
"github.com/akamensky/argparse" "github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
@ -14,9 +14,14 @@ import (
const SockAddr = "/tmp/wgmesh_ipc.sock" const SockAddr = "/tmp/wgmesh_ipc.sock"
func createMesh(client *ipcRpc.Client) string { func createMesh(client *ipcRpc.Client, ifName string, wgPort int) string {
var reply string var reply string
err := client.Call("RobinIpc.CreateMesh", "", &reply) newMeshParams := ipc.NewMeshArgs{
IfName: ifName,
WgPort: wgPort,
}
err := client.Call("RobinIpc.CreateMesh", &newMeshParams, &reply)
if err != nil { if err != nil {
return err.Error() return err.Error()
@ -40,10 +45,15 @@ func listMeshes(client *ipcRpc.Client) {
} }
} }
func joinMesh(client *ipcRpc.Client, meshId string, ipAddress string) string { func joinMesh(client *ipcRpc.Client, meshId string, ipAddress string, ifName string, wgPort int) string {
var reply string var reply string
args := ipc.JoinMeshArgs{MeshId: meshId, IpAdress: ipAddress} args := ipc.JoinMeshArgs{
MeshId: meshId,
IpAdress: ipAddress,
IfName: ifName,
Port: wgPort,
}
err := client.Call("RobinIpc.JoinMesh", &args, &reply) err := client.Call("RobinIpc.JoinMesh", &args, &reply)
@ -69,7 +79,7 @@ func getMesh(client *ipcRpc.Client, meshId string) {
fmt.Println("Control Endpoint: " + node.HostEndpoint) fmt.Println("Control Endpoint: " + node.HostEndpoint)
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint) fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
fmt.Println("Wg IP: " + node.WgHost) fmt.Println("Wg IP: " + node.WgHost)
fmt.Println("Failed Count: " + strconv.FormatBool(node.Failed)) fmt.Println(fmt.Sprintf("Timestamp: %s", time.Unix(node.Timestamp, 0).String()))
fmt.Println("---") fmt.Println("---")
} }
} }
@ -111,8 +121,14 @@ func main() {
enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface") enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format") getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var meshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var meshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var ipAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true}) var ipAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var getMeshId *string = getMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var getMeshId *string = getMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true}) var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -131,7 +147,7 @@ func main() {
} }
if newMeshCmd.Happened() { if newMeshCmd.Happened() {
fmt.Println(createMesh(client)) fmt.Println(createMesh(client, *newMeshIfName, *newMeshPort))
} }
if listMeshCmd.Happened() { if listMeshCmd.Happened() {
@ -139,7 +155,7 @@ func main() {
} }
if joinMeshCmd.Happened() { if joinMeshCmd.Happened() {
fmt.Println(joinMesh(client, *meshId, *ipAddress)) fmt.Println(joinMesh(client, *meshId, *ipAddress, *joinMeshIfName, *joinMeshPort))
} }
if getMeshCmd.Happened() { if getMeshCmd.Happened() {

View File

@ -11,7 +11,8 @@ import (
"github.com/tim-beatham/wgmesh/pkg/middleware" "github.com/tim-beatham/wgmesh/pkg/middleware"
"github.com/tim-beatham/wgmesh/pkg/robin" "github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync" "github.com/tim-beatham/wgmesh/pkg/sync"
wg "github.com/tim-beatham/wgmesh/pkg/wg" "github.com/tim-beatham/wgmesh/pkg/timestamp"
"golang.zx2c4.com/wireguard/wgctrl"
) )
func main() { func main() {
@ -20,7 +21,12 @@ func main() {
log.Fatalln("Could not parse configuration") log.Fatalln("Could not parse configuration")
} }
wgClient, err := wg.CreateClient(conf.IfName, conf.WgPort) client, err := wgctrl.New()
if err != nil {
logging.Log.WriteErrorf("Failed to create wgctrl client")
return
}
var robinRpc robin.RobinRpc var robinRpc robin.RobinRpc
var robinIpc robin.RobinIpc var robinIpc robin.RobinIpc
@ -28,17 +34,18 @@ func main() {
var syncProvider sync.SyncServiceImpl var syncProvider sync.SyncServiceImpl
ctrlServerParams := ctrlserver.NewCtrlServerParams{ ctrlServerParams := ctrlserver.NewCtrlServerParams{
WgClient: wgClient,
Conf: conf, Conf: conf,
AuthProvider: &authProvider, AuthProvider: &authProvider,
CtrlProvider: &robinRpc, CtrlProvider: &robinRpc,
SyncProvider: &syncProvider, SyncProvider: &syncProvider,
Client: client,
} }
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer) syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, 2) syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, 2)
timestampScheduler := timestamp.NewTimestampScheduler(ctrlServer, 60)
robinIpcParams := robin.RobinIpcParams{ robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer, CtrlServer: ctrlServer,
@ -57,6 +64,7 @@ func main() {
go ipc.RunIpcHandler(&robinIpc) go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run() go syncScheduler.Run()
go timestampScheduler.Run()
err = ctrlServer.ConnectionServer.Listen() err = ctrlServer.ConnectionServer.Listen()
@ -67,6 +75,7 @@ func main() {
} }
defer syncScheduler.Stop() defer syncScheduler.Stop()
defer timestampScheduler.Stop()
defer ctrlServer.Close() defer ctrlServer.Close()
defer wgClient.Close() defer client.Close()
} }

View File

@ -5,26 +5,30 @@ import (
"fmt" "fmt"
"net" "net"
"strings" "strings"
"time"
"github.com/automerge/automerge-go" "github.com/automerge/automerge-go"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
// CrdtNodeManager manages nodes in the crdt mesh // CrdtNodeManager manages nodes in the crdt mesh
type CrdtNodeManager struct { type CrdtNodeManager struct {
MeshId string MeshId string
IfName string IfName string
NodeId string NodeId string
Client *wgctrl.Client Client *wgctrl.Client
doc *automerge.Doc doc *automerge.Doc
LastHash automerge.ChangeHash
} }
const maxFails = 5 const maxFails = 5
func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) { func (c *CrdtNodeManager) AddNode(crdt MeshNodeCrdt) {
crdt.FailedMap = automerge.NewMap() crdt.FailedMap = automerge.NewMap()
crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
} }
@ -61,29 +65,22 @@ func (c *CrdtNodeManager) Save() []byte {
return c.doc.Save() return c.doc.Save()
} }
func (c *CrdtNodeManager) LoadChanges(changes []byte) error {
err := c.doc.LoadIncremental(changes)
if err != nil {
return err
}
return nil
}
func (c *CrdtNodeManager) SaveChanges() []byte {
return c.doc.SaveIncremental()
}
// NewCrdtNodeManager: Create a new crdt node manager // NewCrdtNodeManager: Create a new crdt node manager
func NewCrdtNodeManager(meshId, hostId, devName string, client *wgctrl.Client) *CrdtNodeManager { func NewCrdtNodeManager(meshId, hostId, devName string, port int, client *wgctrl.Client) (*CrdtNodeManager, error) {
var manager CrdtNodeManager var manager CrdtNodeManager
manager.MeshId = meshId manager.MeshId = meshId
manager.doc = automerge.New() manager.doc = automerge.New()
manager.IfName = devName manager.IfName = devName
manager.Client = client manager.Client = client
manager.NodeId = hostId manager.NodeId = hostId
return &manager
err := wg.CreateWgInterface(client, devName, port)
if err != nil {
return nil, err
}
return &manager, nil
} }
func (m *CrdtNodeManager) convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfig, error) { func (m *CrdtNodeManager) convertMeshNode(node MeshNodeCrdt) (*wgtypes.PeerConfig, error) {
@ -193,7 +190,29 @@ func (m *CrdtNodeManager) Length() int {
return m.doc.Path("nodes").Map().Len() return m.doc.Path("nodes").Map().Len()
} }
const thresholdVotes = 0.1 func (m *CrdtNodeManager) GetDevice() (*wgtypes.Device, error) {
dev, err := m.Client.Device(m.IfName)
if err != nil {
return nil, err
}
return dev, nil
}
// HasChanges returns true if we have changes since the last time we synced
func (m *CrdtNodeManager) HasChanges() bool {
changes, err := m.doc.Changes(m.LastHash)
logging.Log.WriteInfof("Changes %s", m.LastHash.String())
if err != nil {
return false
}
logging.Log.WriteInfof("Changes length %d", len(changes))
return len(changes) > 0
}
func (m *CrdtNodeManager) HasFailed(endpoint string) bool { func (m *CrdtNodeManager) HasFailed(endpoint string) bool {
node, err := m.GetNode(endpoint) node, err := m.GetNode(endpoint)
@ -222,6 +241,30 @@ func (m *CrdtNodeManager) HasFailed(endpoint string) bool {
return countFailed >= 4 return countFailed >= 4
} }
func (m *CrdtNodeManager) SaveChanges() {
hashes := m.doc.Heads()
hash := hashes[len(hashes)-1]
logging.Log.WriteInfof("Saved Hash %s", hash.String())
m.LastHash = hash
}
func (m *CrdtNodeManager) UpdateTimeStamp() error {
node, err := m.doc.Path("nodes").Map().Get(m.NodeId)
if err != nil {
return err
}
err = node.Map().Set("timestamp", time.Now().Unix())
if err == nil {
logging.Log.WriteInfof("Timestamp Updated for %s", m.MeshId)
}
return err
}
func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error { func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
peerConfigs := make([]wgtypes.PeerConfig, len(nodes)) peerConfigs := make([]wgtypes.PeerConfig, len(nodes))

View File

@ -2,10 +2,12 @@ package crdt
import ( import (
"github.com/automerge/automerge-go" "github.com/automerge/automerge-go"
logging "github.com/tim-beatham/wgmesh/pkg/log"
) )
type AutomergeSync struct { type AutomergeSync struct {
state *automerge.SyncState state *automerge.SyncState
manager *CrdtNodeManager
} }
func (a *AutomergeSync) GenerateMessage() ([]byte, bool) { func (a *AutomergeSync) GenerateMessage() ([]byte, bool) {
@ -28,6 +30,14 @@ func (a *AutomergeSync) RecvMessage(msg []byte) error {
return nil return nil
} }
func NewAutomergeSync(manager *CrdtNodeManager) *AutomergeSync { func (a *AutomergeSync) Complete() {
return &AutomergeSync{state: automerge.NewSyncState(manager.doc)} logging.Log.WriteInfof("Sync Completed")
a.manager.SaveChanges()
}
func NewAutomergeSync(manager *CrdtNodeManager) *AutomergeSync {
return &AutomergeSync{
state: automerge.NewSyncState(manager.doc),
manager: manager,
}
} }

View File

@ -7,8 +7,8 @@ type MeshNodeCrdt struct {
WgEndpoint string `automerge:"wgEndpoint"` WgEndpoint string `automerge:"wgEndpoint"`
PublicKey string `automerge:"publicKey"` PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"` WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"`
FailedMap *automerge.Map `automerge:"failedMap"` FailedMap *automerge.Map `automerge:"failedMap"`
FailedInt int `automerge:"-"`
} }
type MeshCrdt struct { type MeshCrdt struct {

View File

@ -5,10 +5,12 @@ package conn
import ( import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"time"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/wgmesh/pkg/log"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive"
) )
// PeerConnection represents a client-side connection between two // PeerConnection represents a client-side connection between two
@ -42,7 +44,10 @@ func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnec
func (c *WgCtrlConnection) createGrpcConn() error { func (c *WgCtrlConnection) createGrpcConn() error {
conn, err := grpc.Dial(c.endpoint, conn, err := grpc.Dial(c.endpoint,
grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)), grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)),
) grpc.WithKeepaliveParams(keepalive.ClientParameters{
Time: 10 * time.Minute,
Timeout: 30 * time.Minute,
}))
if err != nil { if err != nil {
logging.Log.WriteErrorf("Could not connect: %s\n", err.Error()) logging.Log.WriteErrorf("Could not connect: %s\n", err.Error())

View File

@ -10,8 +10,8 @@ import (
// NewCtrlServerParams are the params requried to create a new ctrl server // NewCtrlServerParams are the params requried to create a new ctrl server
type NewCtrlServerParams struct { type NewCtrlServerParams struct {
WgClient *wgctrl.Client
Conf *conf.WgMeshConfiguration Conf *conf.WgMeshConfiguration
Client *wgctrl.Client
AuthProvider rpc.AuthenticationServer AuthProvider rpc.AuthenticationServer
CtrlProvider rpc.MeshCtrlServerServer CtrlProvider rpc.MeshCtrlServerServer
SyncProvider rpc.SyncServiceServer SyncProvider rpc.SyncServiceServer
@ -21,8 +21,7 @@ type NewCtrlServerParams struct {
// operation failed // operation failed
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
ctrlServer := new(MeshCtrlServer) ctrlServer := new(MeshCtrlServer)
ctrlServer.Client = params.WgClient ctrlServer.MeshManager = mesh.NewMeshManager(*params.Conf, params.Client)
ctrlServer.MeshManager = mesh.NewMeshManager(*params.WgClient, *params.Conf)
ctrlServer.Conf = params.Conf ctrlServer.Conf = params.Conf
connManagerParams := conn.NewConnectionManageParams{ connManagerParams := conn.NewConnectionManageParams{

View File

@ -17,6 +17,7 @@ type MeshNode struct {
PublicKey string PublicKey string
WgHost string WgHost string
Failed bool Failed bool
Timestamp int64
} }
type Mesh struct { type Mesh struct {

View File

@ -10,9 +10,16 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
) )
type NewMeshArgs struct {
IfName string
WgPort int
}
type JoinMeshArgs struct { type JoinMeshArgs struct {
MeshId string MeshId string
IpAdress string IpAdress string
IfName string
Port int
} }
type GetMeshReply struct { type GetMeshReply struct {
@ -24,7 +31,7 @@ type ListMeshReply struct {
} }
type MeshIpc interface { type MeshIpc interface {
CreateMesh(name string, reply *string) error CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error ListMeshes(name string, reply *ListMeshReply) error
JoinMesh(args JoinMeshArgs, reply *string) error JoinMesh(args JoinMeshArgs, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error GetMesh(meshId string, reply *GetMeshReply) error

View File

@ -16,6 +16,7 @@ type MeshManger struct {
Meshes map[string]*crdt.CrdtNodeManager Meshes map[string]*crdt.CrdtNodeManager
Client *wgctrl.Client Client *wgctrl.Client
HostEndpoint string HostEndpoint string
conf *conf.WgMeshConfiguration
} }
func (m *MeshManger) MeshExists(meshId string) bool { func (m *MeshManger) MeshExists(meshId string) bool {
@ -24,52 +25,32 @@ func (m *MeshManger) MeshExists(meshId string) bool {
} }
// CreateMesh: Creates a new mesh, stores it and returns the mesh id // CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManger) CreateMesh(devName string) (string, error) { func (m *MeshManger) CreateMesh(devName string, port int) (string, error) {
key, err := wgtypes.GenerateKey() key, err := wgtypes.GenerateKey()
if err != nil { if err != nil {
return "", err return "", err
} }
nodeManager := crdt.NewCrdtNodeManager(key.String(), m.HostEndpoint, devName, m.Client) nodeManager, err := crdt.NewCrdtNodeManager(key.String(), m.HostEndpoint, devName, port, m.Client)
if err != nil {
return "", err
}
m.Meshes[key.String()] = nodeManager m.Meshes[key.String()] = nodeManager
return key.String(), nil return key.String(), nil
} }
// UpdateMesh: merge the changes and save it to the device // AddMesh: Add the mesh to the list of meshes
func (m *MeshManger) UpdateMesh(meshId string, changes []byte) error { func (m *MeshManger) AddMesh(meshId string, devName string, port int, meshBytes []byte) error {
mesh, ok := m.Meshes[meshId] mesh, err := crdt.NewCrdtNodeManager(meshId, m.HostEndpoint, devName, port, m.Client)
if !ok {
return errors.New("mesh does not exist")
}
err := mesh.LoadChanges(changes)
if err != nil { if err != nil {
return err return err
} }
return nil err = mesh.Load(meshBytes)
}
// ApplyWg: applies the wireguard configuration changes
func (m *MeshManger) ApplyWg() error {
for _, mesh := range m.Meshes {
err := mesh.ApplyWg()
if err != nil {
return err
}
}
return nil
}
// AddMesh: Add the mesh to the list of meshes
func (m *MeshManger) AddMesh(meshId string, devName string, meshBytes []byte) error {
mesh := crdt.NewCrdtNodeManager(meshId, m.HostEndpoint, devName, m.Client)
err := mesh.Load(meshBytes)
if err != nil { if err != nil {
return err return err
@ -84,6 +65,10 @@ func (m *MeshManger) AddMeshNode(meshId string, node crdt.MeshNodeCrdt) {
m.Meshes[meshId].AddNode(node) m.Meshes[meshId].AddNode(node)
} }
func (m *MeshManger) HasChanges(meshId string) bool {
return m.Meshes[meshId].HasChanges()
}
func (m *MeshManger) GetMesh(meshId string) *crdt.CrdtNodeManager { func (m *MeshManger) GetMesh(meshId string) *crdt.CrdtNodeManager {
theMesh, _ := m.Meshes[meshId] theMesh, _ := m.Meshes[meshId]
return theMesh return theMesh
@ -109,6 +94,12 @@ func (s *MeshManger) EnableInterface(meshId string) error {
return errors.New("Node does not exist in the mesh") return errors.New("Node does not exist in the mesh")
} }
err = mesh.ApplyWg()
if err != nil {
return err
}
return wg.EnableInterface(mesh.IfName, node.WgHost) return wg.EnableInterface(mesh.IfName, node.WgHost)
} }
@ -120,7 +111,7 @@ func (s *MeshManger) GetPublicKey(meshId string) (*wgtypes.Key, error) {
return nil, errors.New("mesh does not exist") return nil, errors.New("mesh does not exist")
} }
dev, err := s.Client.Device(mesh.IfName) dev, err := mesh.GetDevice()
if err != nil { if err != nil {
return nil, err return nil, err
@ -129,12 +120,26 @@ func (s *MeshManger) GetPublicKey(meshId string) (*wgtypes.Key, error) {
return &dev.PublicKey, nil return &dev.PublicKey, nil
} }
func NewMeshManager(client wgctrl.Client, conf conf.WgMeshConfiguration) *MeshManger { // UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManger) UpdateTimeStamp() error {
for _, mesh := range s.Meshes {
err := mesh.UpdateTimeStamp()
if err != nil {
return err
}
}
return nil
}
func NewMeshManager(conf conf.WgMeshConfiguration, client *wgctrl.Client) *MeshManger {
ip := lib.GetOutboundIP() ip := lib.GetOutboundIP()
return &MeshManger{ return &MeshManger{
Meshes: make(map[string]*crdt.CrdtNodeManager), Meshes: make(map[string]*crdt.CrdtNodeManager),
Client: &client,
HostEndpoint: fmt.Sprintf("%s:%s", ip.String(), conf.GrpcPort), HostEndpoint: fmt.Sprintf("%s:%s", ip.String(), conf.GrpcPort),
Client: client,
conf: &conf,
} }
} }

View File

@ -3,6 +3,7 @@ package robin
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"strconv" "strconv"
"time" "time"
@ -22,10 +23,10 @@ type RobinIpc struct {
ipAllocator ip.IPAllocator ipAllocator ip.IPAllocator
} }
func (n *RobinIpc) CreateMesh(name string, reply *string) error { func (n *RobinIpc) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
wg.CreateInterface(n.Server.Conf.IfName) wg.CreateInterface(args.IfName)
meshId, err := n.Server.MeshManager.CreateMesh(n.Server.Conf.IfName) meshId, err := n.Server.MeshManager.CreateMesh(args.IfName, args.WgPort)
if err != nil { if err != nil {
return err return err
@ -46,9 +47,9 @@ func (n *RobinIpc) CreateMesh(name string, reply *string) error {
outBoundIp := lib.GetOutboundIP() outBoundIp := lib.GetOutboundIP()
meshNode := crdt.MeshNodeCrdt{ meshNode := crdt.MeshNodeCrdt{
HostEndpoint: outBoundIp.String() + ":8080", HostEndpoint: fmt.Sprintf("%s:%s", outBoundIp.String(), n.Server.Conf.GrpcPort),
PublicKey: pubKey.String(), PublicKey: pubKey.String(),
WgEndpoint: outBoundIp.String() + ":51820", WgEndpoint: fmt.Sprintf("%s:%d", outBoundIp.String(), args.WgPort),
WgHost: nodeIP.String() + "/128", WgHost: nodeIP.String() + "/128",
} }
@ -99,7 +100,7 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err return err
} }
err = n.Server.MeshManager.AddMesh(args.MeshId, n.Server.Conf.IfName, meshReply.Mesh) err = n.Server.MeshManager.AddMesh(args.MeshId, args.IfName, args.Port, meshReply.Mesh)
if err != nil { if err != nil {
return err return err
@ -122,31 +123,14 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
outBoundIP := lib.GetOutboundIP() outBoundIP := lib.GetOutboundIP()
node := crdt.MeshNodeCrdt{ node := crdt.MeshNodeCrdt{
HostEndpoint: outBoundIP.String() + ":8080", HostEndpoint: fmt.Sprintf("%s:%s", outBoundIP.String(), n.Server.Conf.GrpcPort),
WgEndpoint: outBoundIP.String() + ":51820", WgEndpoint: fmt.Sprintf("%s:%d", outBoundIP.String(), args.Port),
PublicKey: pubKey.String(), PublicKey: pubKey.String(),
WgHost: ipAddr.String() + "/128", WgHost: ipAddr.String() + "/128",
} }
n.Server.MeshManager.AddMeshNode(args.MeshId, node) n.Server.MeshManager.AddMeshNode(args.MeshId, node)
mesh := n.Server.MeshManager.GetMesh(args.MeshId) *reply = strconv.FormatBool(true)
joinMeshRequest := rpc.JoinMeshRequest{
MeshId: args.MeshId,
Changes: mesh.SaveChanges(),
}
joinReply, err := c.JoinMesh(ctx, &joinMeshRequest)
if err != nil {
return err
}
if err != nil {
return err
}
*reply = strconv.FormatBool(joinReply.GetSuccess())
return nil return nil
} }
@ -169,6 +153,7 @@ func (n *RobinIpc) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
PublicKey: node.PublicKey, PublicKey: node.PublicKey,
WgHost: node.WgHost, WgHost: node.WgHost,
Failed: mesh.HasFailed(node.HostEndpoint), Failed: mesh.HasFailed(node.HostEndpoint),
Timestamp: node.Timestamp,
} }
nodes[i] = node nodes[i] = node

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/wgmesh/pkg/rpc"
) )
@ -54,19 +53,5 @@ func (m *RobinRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*r
} }
func (m *RobinRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) { func (m *RobinRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
mesh := m.Server.MeshManager.GetMesh(request.MeshId)
logging.Log.WriteInfof("[JOINING MESH]: " + request.MeshId)
if mesh == nil {
return nil, errors.New("mesh does not exist")
}
err := m.Server.MeshManager.UpdateMesh(request.MeshId, request.Changes)
if err != nil {
return nil, err
}
return &rpc.JoinMeshReply{Success: true}, nil return &rpc.JoinMeshReply{Success: true}, nil
} }

View File

@ -3,9 +3,11 @@ package sync
import ( import (
"errors" "errors"
"sync" "sync"
"time"
crdt "github.com/tim-beatham/wgmesh/pkg/automerge" crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
) )
@ -26,6 +28,11 @@ const maxAuthentications = 30
// Sync: Sync random nodes // Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error { func (s *SyncerImpl) Sync(meshId string) error {
if !s.manager.HasChanges(meshId) {
logging.Log.WriteInfof("No changes for %s", meshId)
return nil
}
mesh := s.manager.GetMesh(meshId) mesh := s.manager.GetMesh(meshId)
if mesh == nil { if mesh == nil {
@ -55,6 +62,8 @@ func (s *SyncerImpl) Sync(meshId string) error {
meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes) meshNodes := lib.MapValuesWithExclude(snapshot.Nodes, excludedNodes)
randomSubset := lib.RandomSubsetOfLength(meshNodes, subSetLength) randomSubset := lib.RandomSubsetOfLength(meshNodes, subSetLength)
before := time.Now()
var waitGroup sync.WaitGroup var waitGroup sync.WaitGroup
for _, n := range randomSubset { for _, n := range randomSubset {
@ -70,6 +79,8 @@ func (s *SyncerImpl) Sync(meshId string) error {
} }
waitGroup.Wait() waitGroup.Wait()
logging.Log.WriteInfof("SYNC TIME: %v", time.Now().Sub(before))
return nil return nil
} }
@ -83,7 +94,7 @@ func (s *SyncerImpl) SyncMeshes() error {
} }
} }
return s.manager.ApplyWg() return nil
} }
func NewSyncer(m *mesh.MeshManger, r SyncRequester) Syncer { func NewSyncer(m *mesh.MeshManger, r SyncRequester) Syncer {

View File

@ -14,7 +14,7 @@ import (
// SyncRequester: coordinates the syncing of meshes // SyncRequester: coordinates the syncing of meshes
type SyncRequester interface { type SyncRequester interface {
GetMesh(meshId string, endPoint string) error GetMesh(meshId string, ifName string, port int, endPoint string) error
SyncMesh(meshid string, endPoint string) error SyncMesh(meshid string, endPoint string) error
} }
@ -24,7 +24,7 @@ type SyncRequesterImpl struct {
} }
// GetMesh: Retrieves the local state of the mesh at the endpoint // GetMesh: Retrieves the local state of the mesh at the endpoint
func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error { func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endPoint string) error {
peerConnection, err := s.server.ConnectionManager.GetConnection(endPoint) peerConnection, err := s.server.ConnectionManager.GetConnection(endPoint)
if err != nil { if err != nil {
@ -48,7 +48,7 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error {
return err return err
} }
err = s.server.MeshManager.AddMesh(meshId, s.server.Conf.IfName, reply.Mesh) err = s.server.MeshManager.AddMesh(meshId, ifName, port, reply.Mesh)
return err return err
} }
@ -137,7 +137,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe
} }
} }
logging.Log.WriteInfof("SYNC finished") syncer.Complete()
stream.CloseSend() stream.CloseSend()
return nil return nil
} }

View File

@ -14,6 +14,7 @@ type SyncScheduler interface {
Stop() error Stop() error
} }
// SyncSchedulerImpl scheduler for sync scheduling
type SyncSchedulerImpl struct { type SyncSchedulerImpl struct {
syncRate int syncRate int
quit chan struct{} quit chan struct{}

View File

@ -8,7 +8,6 @@ import (
crdt "github.com/tim-beatham/wgmesh/pkg/automerge" crdt "github.com/tim-beatham/wgmesh/pkg/automerge"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/wgmesh/pkg/rpc"
) )
@ -41,11 +40,12 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
var syncer *crdt.AutomergeSync = nil var syncer *crdt.AutomergeSync = nil
for { for {
logging.Log.WriteInfof("Received Attempt")
in, err := stream.Recv() in, err := stream.Recv()
logging.Log.WriteInfof("Received Worked")
if err == io.EOF { if err == io.EOF {
if syncer != nil {
syncer.Complete()
}
return nil return nil
} }
@ -84,6 +84,9 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
} }
if !moreMessages || err == io.EOF { if !moreMessages || err == io.EOF {
if syncer != nil {
syncer.Complete()
}
return nil return nil
} }
} }

View File

@ -0,0 +1,48 @@
package timestamp
import (
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
type TimestampScheduler interface {
Run() error
Stop() error
}
type TimeStampSchedulerImpl struct {
meshManager *mesh.MeshManger
updateRate int
quit chan struct{}
}
func (s *TimeStampSchedulerImpl) Run() error {
ticker := time.NewTicker(time.Duration(s.updateRate) * time.Second)
s.quit = make(chan struct{})
for {
select {
case <-ticker.C:
err := s.meshManager.UpdateTimeStamp()
if err != nil {
logging.Log.WriteErrorf("Update Timestamp Error: %s", err.Error())
}
case <-s.quit:
break
}
}
}
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer, updateRate int) TimestampScheduler {
return &TimeStampSchedulerImpl{meshManager: ctrlServer.MeshManager, updateRate: updateRate}
}
func (s *TimeStampSchedulerImpl) Stop() error {
close(s.quit)
return nil
}

View File

@ -30,23 +30,17 @@ func CreateInterface(ifName string) error {
/* /*
* Create and configure a new WireGuard client * Create and configure a new WireGuard client
*/ */
func CreateClient(ifName string, port int) (*wgctrl.Client, error) { func CreateWgInterface(client *wgctrl.Client, ifName string, port int) error {
err := CreateInterface(ifName) err := CreateInterface(ifName)
if err != nil { if err != nil {
return nil, err return err
}
client, err := wgctrl.New()
if err != nil {
return nil, err
} }
privateKey, err := wgtypes.GeneratePrivateKey() privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
return nil, err return err
} }
var cfg wgtypes.Config = wgtypes.Config{ var cfg wgtypes.Config = wgtypes.Config{
@ -55,7 +49,7 @@ func CreateClient(ifName string, port int) (*wgctrl.Client, error) {
} }
client.ConfigureDevice(ifName, cfg) client.ConfigureDevice(ifName, cfg)
return client, nil return nil
} }
func EnableInterface(ifName string, ip string) error { func EnableInterface(ifName string, ip string) error {
@ -72,7 +66,7 @@ func EnableInterface(ifName string, ip string) error {
return err return err
} }
cmd = exec.Command("/usr/bin/ip", "addr", "add", hostIp.String()+"/64", "dev", "wgmesh") cmd = exec.Command("/usr/bin/ip", "addr", "add", hostIp.String()+"/64", "dev", ifName)
if err := cmd.Run(); err != nil { if err := cmd.Run(); err != nil {
return err return err

View File

@ -1,49 +0,0 @@
package main
import (
"fmt"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func main() {
client, err := wgctrl.New()
if err != nil {
return
}
privateKey, err := wgtypes.GeneratePrivateKey()
var listenPort int = 5109
if err != nil {
return
}
cfg := wgtypes.Config{
PrivateKey: &privateKey,
ListenPort: &listenPort,
}
err = client.ConfigureDevice("utun9", cfg)
if err != nil {
return
}
devices, err := client.Devices()
if err != nil {
return
}
fmt.Printf("Number of devices: %d\n", len(devices))
for _, device := range devices {
fmt.Printf("Device Name: %s\n", device.Name)
fmt.Printf("Listen Port: %d\n", device.ListenPort)
fmt.Printf("Private Key: %s\n", device.PrivateKey.String())
fmt.Printf("Public Key: %s\n", device.PublicKey.String())
}
}