1
0
forked from extern/smegmesh

Compare commits

...

27 Commits

Author SHA1 Message Date
d8e156f13f 36-add-route-path-into-route-object
Added the route path into the route object so that we can
see what meshes packets are routed across.
2023-11-27 18:55:41 +00:00
3fca49a1c9 Merge pull request #35 from tim-beatham/34-fix-routing
34 fix routing
2023-11-27 16:05:06 +00:00
a2517a1e72 34-fix-routing
- Added mesh-to-mesh routing of hop count > 1
- If there is a tie-breaker with respect to the hop-count use consistent
hashing to determine the route to take based on the public key.
2023-11-27 15:56:30 +00:00
aef8b59f22 32-fix-routing
Flooding routes into other meshes a bit like BGP.
2023-11-25 03:15:58 +00:00
4030d17b41 Fixed routing issue 2023-11-24 17:49:06 +00:00
73db65660b Merge pull request #33 from tim-beatham/32-incorporate-dns
32-incorporate-dns
2023-11-24 15:05:40 +00:00
d1a74a7b95 32-incorporate-dns
Incorporated a DNS server. A DNS server can be run to resolve host
names.
2023-11-24 15:04:07 +00:00
f28ed8260d Merge pull request #30 from tim-beatham/29-only-ping-clients-who-have-updated-their-config
29-only-ping-clients-who-have-updated-their-config
2023-11-24 12:39:14 +00:00
2c406718df 29-only-ping-clients-who-have-updated-their-config
Only consider clients who have updated their config when synchronising
with peers. Consider a dead time where we don't have a handshake and
a prune time when we remove them from the WireGuard configuration.
2023-11-24 12:37:54 +00:00
11b003b549 Merge pull request #28 from tim-beatham/27-remove-client-grpc-endpoint
27-remove-client-grpc-endpoint
2023-11-24 12:08:42 +00:00
7be11dbaa3 27-remove-client-grpc-endpoint
Removed a client's grpc endpoint value. Client's aren't publicly
available so there is no need for a client's gRPC endpoint.
Also changed a node ID's to their public key. A node id's public
address is an issue for mobility of clients as their endpoint
is subject to change
2023-11-24 12:07:03 +00:00
e7ac8c5542 Only updating WireGuard config if node exists 2023-11-22 13:08:02 +00:00
09c64c4628 Fixed container file 2023-11-22 12:45:01 +00:00
2c4f18f52b Merge pull request #26 from tim-beatham/25-modify-code-to-use-public-api
25-modify-code-to-use-public-api
2023-11-22 10:42:48 +00:00
4c54022f63 25-modify-code-to-use-public-api
Modify the code to use a public IP address by default if none is
specified
2023-11-22 10:41:54 +00:00
bf0724f6e5 Merge pull request #24 from tim-beatham/24-keepalive-holepunch
24 keepalive holepunch
2023-11-21 21:28:16 +00:00
624bd6e921 24-keepalive
Persistent keep alive working
2023-11-21 21:26:31 +00:00
7b939e0468 24-keepalive-holepunch
Added the ability to hole punch NAT
2023-11-21 20:42:43 +00:00
6e201ebaf5 24-keepalive-holepunch
Nodes acting as peers and nodes acting as clients
2023-11-21 16:42:49 +00:00
06542da03c main
Fixed problems with timestamp not updating
2023-11-21 13:31:34 +00:00
0d63cd6624 main
Adding words.txt for what words
2023-11-20 18:12:58 +00:00
f13319cfc1 Merge pull request #22 from tim-beatham/21-phonetic-words-ipv6
21 phonetic words ipv6
2023-11-20 18:08:49 +00:00
95f4495b0b 21-phonetic-words-ipv6
Simple what 8 words implementation
2023-11-20 18:07:52 +00:00
330fa74ef4 IPv6 What 8 Words
what 8 words for ipv6 started
2023-11-20 15:22:32 +00:00
3e5b57e41f Merge pull request #20 from tim-beatham/19-hash-wg-interface
Hashing the WireGuard interface
2023-11-20 13:04:19 +00:00
b179cd3cf4 Hashing the WireGuard interface
Hashing the interface and using ephmeral ports so that the admin doesn't
choose an interface and port combination. An administrator can alteranatively
decide to provide port but this isn't critical.
2023-11-20 13:03:42 +00:00
8f211aa116 Merge pull request #18 from tim-beatham/26-performance-testing
Stubbing out WireGuard components
2023-11-20 11:29:37 +00:00
45 changed files with 1421 additions and 742 deletions

View File

@ -8,4 +8,5 @@ RUN apt-get update && apt-get install -y \
tmux \
vim
WORKDIR /wgmesh
RUN go mod tidy
RUN go build -o /usr/local/bin ./...

View File

@ -7,11 +7,13 @@ import (
)
func main() {
apiServer, err := api.NewSmegServer()
apiServer, err := api.NewSmegServer(api.ApiServerConf{
WordsFile: "./cmd/api/words.txt",
})
if err != nil {
log.Fatal(err.Error())
}
apiServer.Run(":40000")
apiServer.Run(":8080")
}

257
cmd/api/words.txt Normal file
View File

@ -0,0 +1,257 @@
be
to
of
it
in
we
do
he
on
go
at
if
or
up
by
hi
the
and
you
not
for
but
say
get
she
one
all
can
out
who
now
see
way
how
lot
yes
use
any
day
try
put
let
why
new
off
big
too
ask
man
bit
end
may
own
run
pay
job
old
kid
bad
few
ago
far
buy
set
guy
car
sit
war
win
yet
top
law
cut
low
die
eat
age
hit
air
add
boy
act
tax
oil
eye
son
key
fun
dad
dog
arm
fly
box
gas
lie
hot
gun
per
art
red
fit
bed
fan
mix
mom
sex
bus
fix
bar
lay
ice
bet
bag
due
aid
tie
leg
ban
odd
cup
dry
cry
rid
pop
sir
cat
map
sad
sea
aim
sun
fat
row
egg
tea
god
wed
tip
ear
hat
net
ill
dig
fee
mad
gap
nor
bid
era
toy
sky
bin
owe
wet
tap
pro
ski
cow
pen
van
web
pot
sum
cap
log
pub
pig
joy
raw
rat
via
lip
two
six
ten
lab
ton
mid
bat
hip
gut
sin
non
rub
sub
par
pre
ray
cue
dye
fin
ion
neo
hey
wow
mum
bye
aye
jet
sue
pet
flu
cop
ooh
rip
spy
pie
bug
gum
wan
rap
nut
beg
pin
pit
jam
tag
fax
vet
fry
pad
lad
mud
bay
con
pan
gee
toe
dip
shy
gym
zoo
fox
bow
tin
hop
wee
kit
opt
vow
sew
cab
bee
rob
rig
yep
ego
rib
nod
hug
lap
ash
hum
dam
bum
yen
jar

18
cmd/dns/main.go Normal file
View File

@ -0,0 +1,18 @@
package main
import (
"log"
smegdns "github.com/tim-beatham/wgmesh/pkg/dns"
)
func main() {
server, err := smegdns.NewDns(53)
if err != nil {
log.Fatal(err.Error())
}
defer server.Close()
server.Listen()
}

View File

@ -8,7 +8,9 @@ import (
"time"
"github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
@ -16,7 +18,6 @@ const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct {
Client *ipcRpc.Client
IfName string
WgPort int
Endpoint string
}
@ -24,7 +25,6 @@ type CreateMeshParams struct {
func createMesh(args *CreateMeshParams) string {
var reply string
newMeshParams := ipc.NewMeshArgs{
IfName: args.IfName,
WgPort: args.WgPort,
Endpoint: args.Endpoint,
}
@ -68,7 +68,6 @@ func joinMesh(params *JoinMeshParams) string {
args := ipc.JoinMeshArgs{
MeshId: params.MeshId,
IpAdress: params.IpAddress,
IfName: params.IfName,
Port: params.WgPort,
}
@ -96,9 +95,13 @@ func getMesh(client *ipcRpc.Client, meshId string) {
fmt.Println("Control Endpoint: " + node.HostEndpoint)
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
fmt.Println("Wg IP: " + node.WgHost)
fmt.Println(fmt.Sprintf("Timestamp: %s", time.Unix(node.Timestamp, 0).String()))
fmt.Printf("Timestamp: %s", time.Unix(node.Timestamp, 0).String())
advertiseRoutes := strings.Join(node.Routes, ",")
mapFunc := func(r ctrlserver.MeshRoute) string {
return r.Destination
}
advertiseRoutes := strings.Join(lib.Map(node.Routes, mapFunc), ",")
fmt.Printf("Routes: %s\n", advertiseRoutes)
fmt.Println("---")
@ -240,7 +243,6 @@ func main() {
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network")
// getMeshCmd := parser.NewCommand("get-mesh", "Get a mesh network")
enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
@ -251,14 +253,12 @@ func main() {
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -298,7 +298,6 @@ func main() {
if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{
Client: client,
IfName: *newMeshIfName,
WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint,
}))
@ -311,7 +310,6 @@ func main() {
if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{
Client: client,
IfName: *joinMeshIfName,
WgPort: *joinMeshPort,
IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId,

View File

@ -13,7 +13,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync"
"github.com/tim-beatham/wgmesh/pkg/timestamp"
timer "github.com/tim-beatham/wgmesh/pkg/timers"
"golang.zx2c4.com/wireguard/wgctrl"
)
@ -57,8 +57,9 @@ func main() {
syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester)
timestampScheduler := timestamp.NewTimestampScheduler(ctrlServer)
timestampScheduler := timer.NewTimestampScheduler(ctrlServer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
routeScheduler := timer.NewRouteScheduler(ctrlServer)
robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer,
@ -78,6 +79,7 @@ func main() {
go syncScheduler.Run()
go timestampScheduler.Run()
go pruneScheduler.Run()
go routeScheduler.Run()
closeResources := func() {
logging.Log.WriteInfof("Closing resources")

View File

@ -10,6 +10,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/what8words"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
@ -22,11 +23,36 @@ type ApiServer interface {
type SmegServer struct {
router *gin.Engine
client *ipcRpc.Client
words *what8words.What8Words
}
func meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
routes := make([]Route, len(meshNode.Routes))
for index, route := range meshNode.Routes {
if route.Path == nil {
route.Path = make([]string, 0)
}
routes[index] = Route{
Prefix: route.Destination,
Path: route.Path,
}
}
return routes
}
func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil {
meshNode.Routes = make([]string, 0)
meshNode.Routes = make([]ctrlserver.MeshRoute, 0)
}
alias := meshNode.Alias
if alias == "" {
alias, _ = s.words.ConvertIdentifier(meshNode.WgHost)
}
return &SmegNode{
@ -35,20 +61,20 @@ func meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
Endpoint: meshNode.HostEndpoint,
Timestamp: int(meshNode.Timestamp),
Description: meshNode.Description,
Routes: meshNode.Routes,
Routes: s.routeToApiRoute(meshNode),
PublicKey: meshNode.PublicKey,
Alias: meshNode.Alias,
Alias: alias,
Services: meshNode.Services,
}
}
func meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
var smegMesh SmegMesh
smegMesh.MeshId = meshId
smegMesh.Nodes = make(map[string]SmegNode)
for _, node := range nodes {
smegMesh.Nodes[node.WgHost] = *meshNodeToAPIMeshNode(node)
smegMesh.Nodes[node.WgHost] = *s.meshNodeToAPIMeshNode(node)
}
return smegMesh
@ -62,11 +88,11 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.NewMeshArgs{
IfName: createMesh.IfName,
WgPort: createMesh.WgPort,
}
@ -100,7 +126,6 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap,
IfName: joinMesh.IfName,
Port: joinMesh.WgPort,
}
@ -139,7 +164,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) {
return
}
mesh := meshToAPIMesh(meshidParam, getMeshReply.Nodes)
mesh := s.meshToAPIMesh(meshidParam, getMeshReply.Nodes)
c.JSON(http.StatusOK, mesh)
}
@ -168,7 +193,7 @@ func (s *SmegServer) GetMeshes(c *gin.Context) {
return
}
meshes = append(meshes, meshToAPIMesh(mesh, getMeshReply.Nodes))
meshes = append(meshes, s.meshToAPIMesh(mesh, getMeshReply.Nodes))
}
c.JSON(http.StatusOK, meshes)
@ -179,13 +204,19 @@ func (s *SmegServer) Run(addr string) error {
return s.router.Run(addr)
}
func NewSmegServer() (ApiServer, error) {
func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
words, err := what8words.NewWhat8Words(conf.WordsFile)
if err != nil {
return nil, err
}
router := gin.Default()
router.Use(gin.LoggerWithConfig(gin.LoggerConfig{
@ -195,6 +226,7 @@ func NewSmegServer() (ApiServer, error) {
smegServer := &SmegServer{
router: router,
client: client,
words: words,
}
router.GET("/meshes", smegServer.GetMeshes)

View File

@ -1,5 +1,10 @@
package api
type Route struct {
Prefix string `json:"prefix"`
Path []string `json:"path"`
}
type SmegNode struct {
Alias string `json:"alias"`
WgHost string `json:"wgHost"`
@ -8,7 +13,7 @@ type SmegNode struct {
Timestamp int `json:"timestamp"`
Description string `json:"description"`
PublicKey string `json:"publicKey"`
Routes []string `json:"routes"`
Routes []Route `json:"routes"`
Services map[string]string `json:"services"`
}
@ -18,13 +23,15 @@ type SmegMesh struct {
}
type CreateMeshRequest struct {
IfName string `json:"ifName" binding:"required"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
}
type JoinMeshRequest struct {
IfName string `json:"ifName" binding:"required"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"`
}
type ApiServerConf struct {
WordsFile string
}

View File

@ -35,15 +35,55 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
panic("node must be of type *MeshNodeCrdt")
}
crdt.Routes = make(map[string]interface{})
crdt.Routes = make(map[string]Route)
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
}
func (c *CrdtMeshManager) GetNodeIds() []string {
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
nodeType, err := node.Map().Get("type")
if err != nil || nodeType.Kind() != automerge.KindStr {
return false
}
return nodeType.Str() == string(conf.PEER_ROLE)
}
// isAlive: checks that the node's configuration has been updated
// since the rquired keep alive time
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
timestamp, err := node.Map().Get("timestamp")
if err != nil || timestamp.Kind() != automerge.KindInt64 {
return false
}
keepAliveTime := timestamp.Int64()
return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
}
func (c *CrdtMeshManager) GetPeers() []string {
keys, _ := c.doc.Path("nodes").Map().Keys()
keys = lib.Filter(keys, func(publicKey string) bool {
return c.isPeer(publicKey) && c.isAlive(publicKey)
})
return keys
}
@ -55,7 +95,7 @@ func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return nil, err
}
if c.cache == nil || len(changes) > 3 {
if c.cache == nil || len(changes) > 0 {
c.lastCacheHash = c.LastHash
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
@ -113,7 +153,7 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
// NodeExists: returns true if the node exists. Returns false
func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err != nil
return node.Kind() == automerge.KindMap && err == nil
}
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
@ -279,7 +319,7 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
}
// AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
logging.Log.WriteInfof("Adding route to %s", nodeId)
@ -298,7 +338,10 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
}
for _, route := range routes {
err = routeMap.Map().Set(route, struct{}{})
err = routeMap.Map().Set(route.GetDestination().String(), Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
})
if err != nil {
return err
@ -307,6 +350,67 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
return nil
}
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return nil, err
}
if nodeVal.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node does not exist")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return nil, err
}
if routeMap.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node %s is not a map", nodeId)
}
routes, err := automerge.As[map[string]Route](routeMap)
return lib.MapValues(routes), err
}
func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
if err != nil {
return nil, err
}
routes := make(map[string]mesh.Route)
for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route
}
for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node)
if err != nil {
return nil, err
}
for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()]
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(),
Path: append(route.Path, m.GetMeshId()),
}
}
}
}
return routes, nil
}
// DeleteRoutes deletes the specified routes
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -419,8 +523,13 @@ func (m *MeshNodeCrdt) GetTimeStamp() int64 {
return m.Timestamp
}
func (m *MeshNodeCrdt) GetRoutes() []string {
return lib.MapKeys(m.Routes)
func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route {
return &Route{
Destination: r.Destination,
Path: r.Path,
}
})
}
func (m *MeshNodeCrdt) GetDescription() string {
@ -450,6 +559,12 @@ func (m *MeshNodeCrdt) GetServices() map[string]string {
return services
}
// GetType refers to the type of the node. Peer means that the node is globally accessible
// Client means the node is only accessible through another peer
func (n *MeshNodeCrdt) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode)
@ -464,8 +579,22 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
Description: node.Description,
Alias: node.Alias,
Services: node.GetServices(),
Type: node.Type,
}
}
return nodes
}
func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet
}
func (r *Route) GetHopCount() int {
return len(r.Path)
}
func (r *Route) GetPath() []string {
return r.Path
}

View File

@ -28,16 +28,23 @@ type MeshNodeFactory struct {
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
if f.Config.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
}
return &MeshNodeCrdt{
HostEndpoint: fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort),
HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(),
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort),
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
// Always set the routes as empty.
// Routes handled by external component
Routes: map[string]interface{}{},
Routes: make(map[string]Route),
Description: "",
Alias: "",
Type: string(f.Config.Role),
}
}
@ -50,7 +57,19 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
} else if len(f.Config.Endpoint) != 0 {
hostName = f.Config.Endpoint
} else {
hostName = lib.GetOutboundIP().String()
ipFunc := lib.GetPublicIP
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}
ip, err := ipFunc()
if err != nil {
return ""
}
hostName = ip.String()
}
return hostName

View File

@ -1,16 +1,23 @@
package crdt
// Route: Represents a CRDT of the given route
type Route struct {
Destination string `automerge:"destination"`
Path []string `automerge:"path"`
}
// MeshNodeCrdt: Represents a CRDT for a mesh nodes
type MeshNodeCrdt struct {
HostEndpoint string `automerge:"hostEndpoint"`
WgEndpoint string `automerge:"wgEndpoint"`
PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"`
Routes map[string]interface{} `automerge:"routes"`
Alias string `automerge:"alias"`
Description string `automerge:"description"`
Services map[string]string `automerge:"services"`
HostEndpoint string `automerge:"hostEndpoint"`
WgEndpoint string `automerge:"wgEndpoint"`
PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"`
Routes map[string]Route `automerge:"routes"`
Alias string `automerge:"alias"`
Description string `automerge:"description"`
Services map[string]string `automerge:"services"`
Type string `automerge:"type"`
}
// MeshCrdt: Represents the mesh network as a whole

View File

@ -16,6 +16,20 @@ func (m *WgMeshConfigurationError) Error() string {
return m.msg
}
type NodeType string
const (
PEER_ROLE NodeType = "peer"
CLIENT_ROLE NodeType = "client"
)
type IPDiscovery string
const (
PUBLIC_IP_DISCOVERY = "public"
DNS_IP_DISCOVERY = "dns"
)
type WgMeshConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath"`
@ -28,6 +42,9 @@ type WgMeshConfiguration struct {
SkipCertVerification bool `yaml:"skipCertVerification"`
// Port to run the GrpcServer on
GrpcPort string `yaml:"gRPCPort"`
// IPDIscovery: how to discover your IP if not specified. Use DNS server 8.8.8.8 or
// use public IP discovery library
IPDiscovery IPDiscovery `yaml:"ipDiscovery"`
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
// Endpoint is the IP in which this computer is publicly reachable.
@ -47,12 +64,21 @@ type WgMeshConfiguration struct {
KeepAliveTime int `yaml:"keepAliveTime"`
// Timeout number of seconds before we consider the node as dead
Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we consider the 'node' as dead
// PruneTime number of seconds before we remove nodes that are likely to be dead
PruneTime int `yaml:"pruneTime"`
// DeadTime: number of seconds before we consider the node as dead and stop considering it
// when picking a random peer
DeadTime int `yaml:"deadTime"`
// Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"`
// Role specifies whether or not the user is globally accessible.
// If the user is globaly accessible they specify themselves as a client.
Role NodeType `yaml:"role"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg int `yaml:"keepAliveWg"`
}
func ValidateConfiguration(c *WgMeshConfiguration) error {
@ -122,9 +148,15 @@ func ValidateConfiguration(c *WgMeshConfiguration) error {
}
}
if c.PruneTime <= 1 {
if c.PruneTime < 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be <= 1",
msg: "Prune time cannot be < 1",
}
}
if c.DeadTime < 1 {
return &WgMeshConfigurationError{
msg: "Dead time cannot be < 1",
}
}
@ -134,6 +166,14 @@ func ValidateConfiguration(c *WgMeshConfiguration) error {
}
}
if c.Role == "" {
c.Role = PEER_ROLE
}
if c.IPDiscovery == "" {
c.IPDiscovery = PUBLIC_IP_DISCOVERY
}
return nil
}

View File

@ -31,11 +31,11 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
nodeFactory := crdt.MeshNodeFactory{
Config: *params.Conf,
}
idGenerator := &lib.UUIDGenerator{}
idGenerator := &lib.IDNameGenerator{}
ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
configApplyer := mesh.NewWgMeshConfigApplyer()
configApplyer := mesh.NewWgMeshConfigApplyer(params.Conf)
meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf,

View File

@ -9,6 +9,11 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshRoute struct {
Destination string
Path []string
}
// Represents a WireGuard MeshNode
type MeshNode struct {
HostEndpoint string
@ -16,7 +21,7 @@ type MeshNode struct {
PublicKey string
WgHost string
Timestamp int64
Routes []string
Routes []MeshRoute
Description string
Alias string
Services map[string]string

114
pkg/dns/dns.go Normal file
View File

@ -0,0 +1,114 @@
package smegdns
import (
"encoding/json"
"fmt"
"net"
"net/rpc"
"github.com/miekg/dns"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/query"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.`
type DNSHandler struct {
client *rpc.Client
server *dns.Server
}
// queryMesh: queries the mesh network for the given meshId and node
// with alias
func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
var reply string
err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{
MeshId: meshId,
Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias),
}, &reply)
if err != nil {
return nil
}
var node *query.QueryNode
err = json.Unmarshal([]byte(reply), &node)
if err != nil || node == nil {
return nil
}
ip, _, _ := net.ParseCIDR(node.WgHost)
return ip
}
func (d *DNSHandler) handleQuery(m *dns.Msg) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeAAAA:
logging.Log.WriteInfof("Query for %s", q.Name)
groups := lib.MatchCaptureGroup(MeshRegularExpression, q.Name)
if len(groups) == 0 {
continue
}
ip := d.queryMesh(groups["meshId"], groups["alias"])
rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip))
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
}
}
func (h *DNSHandler) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
msg := new(dns.Msg)
msg.SetReply(r)
msg.Authoritative = true
switch r.Opcode {
case dns.OpcodeQuery:
h.handleQuery(msg)
}
w.WriteMsg(msg)
}
func (h *DNSHandler) Listen() error {
return h.server.ListenAndServe()
}
func (h *DNSHandler) Close() error {
return h.server.Shutdown()
}
func NewDns(udpPort int) (*DNSHandler, error) {
client, err := rpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
dnsHander := DNSHandler{
client: client,
}
dns.HandleFunc("smeg.", dnsHander.handleDnsRequest)
dnsHander.server = &dns.Server{Addr: fmt.Sprintf(":%d", udpPort), Net: "udp"}
return &dnsHander, nil
}

View File

@ -4,13 +4,13 @@ package rpctypes;
option go_package = "pkg/rpc";
service MeshCtrlServer {
rpc JoinMesh(JoinMeshRequest) returns (JoinMeshReply) {}
rpc GetMesh(GetMeshRequest) returns (GetMeshReply) {}
}
message JoinMeshRequest {
string meshId = 2;
message GetMeshRequest {
string meshId = 1;
}
message JoinMeshReply {
bool success = 1;
message GetMeshReply {
bytes mesh = 1;
}

View File

@ -11,8 +11,6 @@ import (
)
type NewMeshArgs struct {
// IfName is the interface that the mesh instance will run on
IfName string
// WgPort is the WireGuard port to expose
WgPort int
// Endpoint is the routable alias of the machine. Can be an IP
@ -25,13 +23,14 @@ type JoinMeshArgs struct {
MeshId string
// IpAddress is a routable IP in another mesh
IpAdress string
// IfName is the interface name of the mesh
IfName string
// Port is the WireGuard port to expose
Port int
// Endpoint is the routable address of this machine. If not provided
// defaults to the default address
Endpoint string
// Client specifies whether we should join as a client of the peer
// we are connecting to
Client bool
}
type PutServiceArgs struct {
@ -63,7 +62,6 @@ type MeshIpc interface {
JoinMesh(args JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error
EnableInterface(meshId string, reply *string) error
GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error

View File

@ -66,3 +66,13 @@ func Filter[V any](list []V, f filterFunc[V]) []V {
return newList
}
func Contains[V any](list []V, proposition func(V) bool) bool {
for _, elem := range list {
if proposition(elem) {
return true
}
}
return false
}

46
pkg/lib/hashing.go Normal file
View File

@ -0,0 +1,46 @@
package lib
import (
"hash/fnv"
"sort"
)
type consistentHashRecord[V any] struct {
record V
value int
}
func HashString(value string) int {
f := fnv.New32a()
f.Write([]byte(value))
return int(f.Sum32())
}
// ConsistentHash implementation. Traverse the values until we find a key
// less than ours.
func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int, keyFunc func(K) int) V {
if len(values) == 0 {
panic("values is empty")
}
vs := Map(values, func(v V) consistentHashRecord[V] {
return consistentHashRecord[V]{
v,
bucketFunc(v),
}
})
sort.SliceStable(vs, func(i, j int) bool {
return vs[i].value < vs[j].value
})
ourKey := keyFunc(client)
for _, record := range vs {
if ourKey < record.value {
return record.record
}
}
return vs[0].record
}

View File

@ -1,6 +1,9 @@
package lib
import "github.com/google/uuid"
import (
"github.com/anandvarma/namegen"
"github.com/google/uuid"
)
// IdGenerator generates unique ids
type IdGenerator interface {
@ -15,3 +18,11 @@ func (g *UUIDGenerator) GetId() (string, error) {
id := uuid.New()
return id.String(), nil
}
type IDNameGenerator struct {
}
func (i *IDNameGenerator) GetId() (string, error) {
name_schema := namegen.New()
return name_schema.Get(), nil
}

View File

@ -1,17 +1,61 @@
package lib
import (
"encoding/json"
"io"
"log"
"net"
"net/http"
)
// GetOutboundIP: gets the oubound IP of this packet
func GetOutboundIP() net.IP {
func GetOutboundIP() (net.IP, error) {
conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil {
log.Fatal(err)
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP
return localAddr.IP, nil
}
const IP_SERVICE = "https://api.ipify.org?format=json"
type IpResponse struct {
Ip string `json:"ip"`
}
func (i *IpResponse) GetIP() net.IP {
return net.ParseIP(i.Ip)
}
// GetPublicIP: get the nodes public IP address. For when a node is behind NAT
func GetPublicIP() (net.IP, error) {
req, err := http.NewRequest(http.MethodGet, IP_SERVICE, nil)
if err != nil {
return nil, err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
resBody, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
var jsonResponse IpResponse
err = json.Unmarshal([]byte(resBody), &jsonResponse)
if err != nil {
return nil, err
}
return jsonResponse.GetIP(), nil
}

19
pkg/lib/regex.go Normal file
View File

@ -0,0 +1,19 @@
package lib
import "regexp"
func MatchCaptureGroup(pattern, payload string) map[string]string {
patterns := make(map[string]string)
expr := regexp.MustCompile(pattern)
match := expr.FindStringSubmatch(payload)
for i, name := range expr.SubexpNames() {
if i != 0 && name != "" {
patterns[name] = match[i]
}
}
return patterns
}

View File

@ -201,7 +201,7 @@ func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
})
if err != nil {
return fmt.Errorf("failed to delete route %w", err)
return fmt.Errorf("failed to delete route %s", dst.IP.String())
}
return nil
@ -219,22 +219,15 @@ func (r1 Route) equal(r2 Route) bool {
// DeleteRoutes deletes all routes not in exclude
func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error {
routes := make([]rtnetlink.RouteMessage, 0)
routes, err := c.listRoutes(ifName, family)
if len(exclude) != 0 {
lRoutes, err := c.listRoutes(ifName, family, exclude[0].Gateway)
if err != nil {
return err
}
routes = lRoutes
if err != nil {
return err
}
ifRoutes := make([]Route, 0)
for _, rtRoute := range routes {
logging.Log.WriteInfof("Routes: %s", rtRoute.Attributes.Dst.String())
maskSize := 128
if family == unix.AF_INET {
@ -262,7 +255,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
toDelete := Filter(ifRoutes, shouldExclude)
for _, route := range toDelete {
logging.Log.WriteInfof("Deleting route %s", route.Destination.String())
logging.Log.WriteInfof("Deleting route: %s", route.Gateway.String())
err := c.DeleteRoute(ifName, route)
if err != nil {
@ -274,7 +267,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
}
// listRoutes lists all routes on the interface
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP) ([]rtnetlink.RouteMessage, error) {
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.RouteMessage, error) {
iface, err := net.InterfaceByName(ifName)
if err != nil {
@ -288,7 +281,7 @@ func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP
}
filterFunc := func(r rtnetlink.RouteMessage) bool {
return r.Attributes.Gateway.Equal(gateway) && r.Attributes.OutIface == uint32(iface.Index)
return r.Attributes.Gateway != nil && r.Attributes.OutIface == uint32(iface.Index)
}
routes = Filter(routes, filterFunc)

View File

@ -3,7 +3,13 @@ package mesh
import (
"fmt"
"net"
"slices"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -16,10 +22,24 @@ type MeshConfigApplyer interface {
// WgMeshConfigApplyer applies WireGuard configuration
type WgMeshConfigApplyer struct {
meshManager MeshManager
meshManager MeshManager
config *conf.WgMeshConfiguration
routeInstaller route.RouteInstaller
}
func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
type routeNode struct {
gateway string
route Route
}
func (r *routeNode) equals(route2 *routeNode) bool {
return r.gateway == route2.gateway && RouteEquals(r.route, route2.route)
}
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
peerToClients map[string][]net.IPNet,
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
if err != nil {
@ -35,20 +55,109 @@ func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
allowedips := make([]net.IPNet, 1)
allowedips[0] = *node.GetWgHost()
clients, ok := peerToClients[node.GetWgHost().String()]
if ok {
allowedips = append(allowedips, clients...)
}
for _, route := range node.GetRoutes() {
_, ipnet, _ := net.ParseCIDR(route)
allowedips = append(allowedips, *ipnet)
bestRoutes := routes[route.GetDestination().String()]
if len(bestRoutes) == 1 {
allowedips = append(allowedips, *route.GetDestination())
} else if len(bestRoutes) > 1 {
keyFunc := func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
return lib.HashString(pubKey.String())
}
bucketFunc := func(rn routeNode) int {
return lib.HashString(rn.gateway)
}
// Else there is more than one candidate so consistently hash
pickedRoute := lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc)
if pickedRoute.gateway == pubKey.String() {
allowedips = append(allowedips, *route.GetDestination())
}
}
}
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
existing := slices.IndexFunc(device.Peers, func(p wgtypes.Peer) bool {
pubKey, _ := node.GetPublicKey()
return p.PublicKey.String() == pubKey.String()
})
if existing != -1 {
endpoint = device.Peers[existing].Endpoint
}
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
Endpoint: endpoint,
AllowedIPs: allowedips,
PublicKey: pubKey,
Endpoint: endpoint,
AllowedIPs: allowedips,
PersistentKeepaliveInterval: &keepAlive,
}
return &peerConfig, nil
}
// getRoutes: finds the routes with the least hop distance. If more than one route exists
// consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode {
mesh, _ := meshProvider.GetMesh()
routes := make(map[string][]routeNode)
meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
return ipNet
})
for _, node := range mesh.GetNodes() {
pubKey, _ := node.GetPublicKey()
meshRoutes, _ := meshProvider.GetRoutes(pubKey.String())
for _, route := range meshRoutes {
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
if prefix == nil || route == nil || route.GetDestination() == nil {
return false
}
return prefix.Contains(route.GetDestination().IP)
}) {
continue
}
destination := route.GetDestination().String()
otherRoute, ok := routes[destination]
rn := routeNode{
gateway: pubKey.String(),
route: route,
}
if !ok {
otherRoute = make([]routeNode, 1)
otherRoute[0] = rn
routes[destination] = otherRoute
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
otherRoute[0] = rn
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
routes[destination] = append(otherRoute, rn)
}
}
}
return routes
}
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
snap, err := mesh.GetMesh()
@ -56,18 +165,67 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err
}
nodes := snap.GetNodes()
nodes := lib.MapValues(snap.GetNodes())
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
peers := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.PEER_ROLE
})
var count int = 0
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
if err != nil {
return err
}
peerToClients := make(map[string][]net.IPNet)
routes := m.getRoutes(mesh)
installedRoutes := make([]lib.Route, 0)
for _, n := range nodes {
peer, err := convertMeshNode(n)
if NodeEquals(n, self) {
continue
}
if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE {
hashFunc := func(mn MeshNode) int {
return lib.HashString(mn.GetWgHost().String())
}
peer := lib.ConsistentHash(peers, n, hashFunc, hashFunc)
clients, ok := peerToClients[peer.GetWgHost().String()]
if !ok {
clients = make([]net.IPNet, 0)
peerToClients[peer.GetWgHost().String()] = clients
}
peerToClients[peer.GetWgHost().String()] = append(clients, *n.GetWgHost())
continue
}
dev, _ := mesh.GetDevice()
peer, err := m.convertMeshNode(n, dev, peerToClients, routes)
if err != nil {
return err
}
for _, route := range peer.AllowedIPs {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
if !ipNet.Contains(route.IP) {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: n.GetWgHost().IP,
Destination: route,
})
}
}
peerConfigs[count] = *peer
count++
}
@ -82,6 +240,12 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err
}
err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
if err != nil {
return err
}
return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
}
@ -112,7 +276,7 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
ReplacePeers: true,
Peers: make([]wgtypes.PeerConfig, 1),
Peers: make([]wgtypes.PeerConfig, 0),
})
return nil
@ -122,6 +286,9 @@ func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager
}
func NewWgMeshConfigApplyer() MeshConfigApplyer {
return &WgMeshConfigApplyer{}
func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer {
return &WgMeshConfigApplyer{
config: config,
routeInstaller: route.NewRouteInstaller(),
}
}

View File

@ -61,7 +61,7 @@ func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId strin
self, _ := c.manager.GetSelf(meshId)
if node.GetHostEndpoint() == self.GetHostEndpoint() {
if NodeEquals(self, node) {
return
}

View File

@ -14,11 +14,10 @@ import (
)
type MeshManager interface {
CreateMesh(devName string, port int) (string, error)
CreateMesh(port int) (string, error)
AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider
EnableInterface(meshId string) error
GetPublicKey(meshId string) (*wgtypes.Key, error)
AddSelf(params *AddSelfParams) error
LeaveMesh(meshId string) error
@ -35,6 +34,7 @@ type MeshManager interface {
Close() error
GetMonitor() MeshMonitor
GetNode(string, string) MeshNode
GetRouteManager() RouteManager
}
type MeshManagerImpl struct {
@ -54,10 +54,15 @@ type MeshManagerImpl struct {
Monitor MeshMonitor
}
// GetRouteManager implements MeshManager.
func (m *MeshManagerImpl) GetRouteManager() RouteManager {
return m.RouteManager
}
// RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.HostEndpoint, service)
err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
if err != nil {
return err
@ -70,7 +75,7 @@ func (m *MeshManagerImpl) RemoveService(service string) error {
// SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.HostEndpoint, service, value)
err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
if err != nil {
return err
@ -115,15 +120,25 @@ func (m *MeshManagerImpl) Prune() error {
}
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
meshId, err := m.idGenerator.GetId()
var ifName string = ""
if err != nil {
return "", err
}
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(port, m.HostParameters.PrivateKey)
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: devName,
DevName: ifName,
Port: port,
Conf: m.conf,
Client: m.Client,
@ -134,32 +149,31 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
return "", fmt.Errorf("error creating mesh: %w", err)
}
if !m.conf.StubWg {
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName,
Port: port,
})
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
m.Meshes[meshId] = nodeManager
return meshId, nil
}
type AddMeshParams struct {
MeshId string
DevName string
WgPort int
MeshBytes []byte
}
// AddMesh: Add the mesh to the list of meshes
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string
var err error
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey)
if err != nil {
return err
}
}
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: params.DevName,
DevName: ifName,
Port: params.WgPort,
Conf: m.conf,
Client: m.Client,
@ -177,14 +191,6 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
}
m.Meshes[params.MeshId] = meshProvider
if !m.conf.StubWg {
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName,
Port: params.WgPort,
})
}
return nil
}
@ -199,23 +205,6 @@ func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
return theMesh
}
// EnableInterface: Enables the given WireGuard interface.
func (s *MeshManagerImpl) EnableInterface(meshId string) error {
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
err = s.RouteManager.InstallRoutes()
if err != nil {
return err
}
return nil
}
// GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) {
if s.conf.StubWg {
@ -255,20 +244,26 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
return fmt.Errorf("addself: mesh %s does not exist", params.MeshId)
}
pubKey, err := s.GetPublicKey(params.MeshId)
if params.WgPort == 0 && !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return err
if err != nil {
return err
}
params.WgPort = device.ListenPort
}
nodeIP, err := s.ipAllocator.GetIP(*pubKey, params.MeshId)
pubKey := s.HostParameters.PrivateKey.PublicKey()
nodeIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId)
if err != nil {
return err
}
node := s.nodeFactory.Build(&MeshNodeFactoryParams{
PublicKey: pubKey,
PublicKey: &pubKey,
NodeIP: nodeIP,
WgPort: params.WgPort,
Endpoint: params.Endpoint,
@ -300,22 +295,23 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
return fmt.Errorf("mesh %s does not exist", meshId)
}
err := s.RouteManager.RemoveRoutes(meshId)
if err != nil {
return err
}
var err error
if !s.conf.StubWg {
device, e := mesh.GetDevice()
device, err := mesh.GetDevice()
if e != nil {
if err != nil {
return err
}
err = s.interfaceManipulator.RemoveInterface(device.Name)
if err != nil {
return err
}
}
err = s.RouteManager.RemoveRoutes(meshId)
delete(s.Meshes, meshId)
return err
}
@ -327,7 +323,8 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return nil, fmt.Errorf("mesh %s does not exist", meshId)
}
node, err := meshInstance.GetNode(s.HostParameters.HostEndpoint)
logging.Log.WriteInfof(s.HostParameters.GetPublicKey())
node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
if err != nil {
return nil, errors.New("the node doesn't exist in the mesh")
@ -337,6 +334,10 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
}
func (s *MeshManagerImpl) ApplyConfig() error {
if s.conf.StubWg {
return nil
}
err := s.configApplyer.ApplyConfig()
if err != nil {
@ -348,8 +349,8 @@ func (s *MeshManagerImpl) ApplyConfig() error {
func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
if err != nil {
return err
@ -363,8 +364,8 @@ func (s *MeshManagerImpl) SetDescription(description string) error {
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias)
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
if err != nil {
return err
@ -377,8 +378,8 @@ func (s *MeshManagerImpl) SetAlias(alias string) error {
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) {
err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey())
if err != nil {
return err
@ -435,17 +436,11 @@ type NewMeshManagerParams struct {
// Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) MeshManager {
hostParams := HostParameters{}
switch params.Conf.Endpoint {
case "":
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", lib.GetOutboundIP().String(), params.Conf.GrpcPort)
default:
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", params.Conf.Endpoint, params.Conf.GrpcPort)
privateKey, _ := wgtypes.GeneratePrivateKey()
hostParams := HostParameters{
PrivateKey: &privateKey,
}
logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint)
m := &MeshManagerImpl{
Meshes: make(map[string]MeshProvider),
HostParameters: &hostParams,

View File

@ -64,7 +64,6 @@ func TestAddMeshAddsAMesh(t *testing.T) {
manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
@ -83,7 +82,6 @@ func TestAddMeshMeshAlreadyExistsReplacesIt(t *testing.T) {
for i := 0; i < 2; i++ {
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
@ -106,7 +104,6 @@ func TestAddSelfAddsSelfToTheMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
@ -175,7 +172,6 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
@ -201,8 +197,8 @@ func TestSetDescription(t *testing.T) {
manager := getMeshManager()
description := "wooooo"
meshId1, _ := manager.CreateMesh("wg0", 5000)
meshId2, _ := manager.CreateMesh("wg0", 5001)
meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
@ -225,8 +221,8 @@ func TestSetDescription(t *testing.T) {
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh("wg0", 5000)
meshId2, _ := manager.CreateMesh("wg0", 5001)
meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,

View File

@ -1,25 +1,18 @@
package mesh
import (
"fmt"
"net"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.org/x/sys/unix"
)
type RouteManager interface {
UpdateRoutes() error
InstallRoutes() error
RemoveRoutes(meshId string) error
}
type RouteManagerImpl struct {
meshManager MeshManager
routeInstaller route.RouteInstaller
meshManager MeshManager
}
func (r *RouteManagerImpl) UpdateRoutes() error {
@ -27,6 +20,24 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
ulaBuilder := new(ip.ULABuilder)
for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
if err != nil {
return err
}
pubKey, err := self.GetPublicKey()
if err != nil {
return err
}
routes, err := mesh1.GetRoutes(pubKey.String())
if err != nil {
return err
}
for _, mesh2 := range meshes {
if mesh1 == mesh2 {
continue
@ -39,13 +50,11 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return err
}
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
if err != nil {
return err
}
err = mesh1.AddRoutes(self.GetHostEndpoint(), ipNet.String())
err = mesh2.AddRoutes(NodeID(self), append(lib.MapValues(routes), &RouteStub{
Destination: ipNet,
HopCount: 0,
Path: make([]string, 0),
})...)
if err != nil {
return err
@ -74,111 +83,11 @@ func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
return err
}
mesh1.RemoveRoutes(self.GetHostEndpoint(), ipNet.String())
mesh1.RemoveRoutes(NodeID(self), ipNet.String())
}
return nil
}
// AddRoute adds a route to the given interface
func (m *RouteManagerImpl) addRoute(ifName string, meshPrefix string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return fmt.Errorf("failed to create config: %w", err)
}
defer rtnl.Close()
// Delete any routes that may be vacant
err = rtnl.DeleteRoutes(ifName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes {
if route.Destination.String() == meshPrefix {
continue
}
err = rtnl.AddRoute(ifName, route)
if err != nil {
return err
}
}
return nil
}
func (m *RouteManagerImpl) installRoute(ifName string, meshid string, node MeshNode) error {
routeMapFunc := func(route string) lib.Route {
_, cidr, _ := net.ParseCIDR(route)
r := lib.Route{
Destination: *cidr,
Gateway: node.GetWgHost().IP,
}
return r
}
ipBuilder := &ip.ULABuilder{}
ipNet, err := ipBuilder.GetIPNet(meshid)
if err != nil {
return err
}
routes := lib.Map(append(node.GetRoutes(), ipNet.String()), routeMapFunc)
return m.addRoute(ifName, ipNet.String(), routes...)
}
func (m *RouteManagerImpl) installRoutes(meshProvider MeshProvider) error {
mesh, err := meshProvider.GetMesh()
if err != nil {
return err
}
dev, err := meshProvider.GetDevice()
if err != nil {
return err
}
self, err := m.meshManager.GetSelf(meshProvider.GetMeshId())
if err != nil {
return err
}
for _, node := range mesh.GetNodes() {
if self.GetHostEndpoint() == node.GetHostEndpoint() {
continue
}
err = m.installRoute(dev.Name, meshProvider.GetMeshId(), node)
if err != nil {
return err
}
}
return nil
}
// InstallRoutes installs all routes to the RIB
func (r *RouteManagerImpl) InstallRoutes() error {
for _, mesh := range r.meshManager.GetMeshes() {
err := r.installRoutes(mesh)
if err != nil {
return err
}
}
return nil
}
func NewRouteManager(m MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()}
return &RouteManagerImpl{meshManager: m}
}

View File

@ -16,11 +16,16 @@ type MeshNodeStub struct {
wgEndpoint string
wgHost *net.IPNet
timeStamp int64
routes []string
routes []Route
identifier string
description string
}
// GetType implements MeshNode.
func (*MeshNodeStub) GetType() conf.NodeType {
return conf.PEER_ROLE
}
// GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string)
@ -51,7 +56,7 @@ func (m *MeshNodeStub) GetTimeStamp() int64 {
return m.timeStamp
}
func (m *MeshNodeStub) GetRoutes() []string {
func (m *MeshNodeStub) GetRoutes() []Route {
return m.routes
}
@ -76,29 +81,33 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub
}
func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) {
return nil, nil
}
// GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetNodeIds() []string {
panic("unimplemented")
func (*MeshProviderStub) GetPeers() []string {
return make([]string, 0)
}
// GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
panic("unimplemented")
return nil, nil
}
// NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool {
panic("unimplemented")
return false
}
// AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
panic("unimplemented")
return nil
}
// RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
panic("unimplemented")
return nil
}
// SetAlias implements MeshProvider.
@ -108,7 +117,7 @@ func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
// RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
panic("unimplemented")
return nil
}
// Prune implements MeshProvider.
@ -154,7 +163,7 @@ func (s *MeshProviderStub) HasChanges() bool {
return false
}
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...string) error {
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...Route) error {
return nil
}
@ -188,7 +197,7 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort),
wgHost: wgHost,
timeStamp: time.Now().Unix(),
routes: make([]string, 0),
routes: make([]Route, 0),
identifier: "abc",
description: "A Mesh Node Stub",
}
@ -211,6 +220,11 @@ type MeshManagerStub struct {
meshes map[string]MeshProvider
}
// GetRouteManager implements MeshManager.
func (*MeshManagerStub) GetRouteManager() RouteManager {
panic("unimplemented")
}
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented")
@ -250,7 +264,7 @@ func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
}
func (m *MeshManagerStub) CreateMesh(devName string, port int) (string, error) {
func (m *MeshManagerStub) CreateMesh(port int) (string, error) {
return "tim123", nil
}
@ -273,10 +287,6 @@ func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider {
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}}
}
func (m *MeshManagerStub) EnableInterface(meshId string) error {
return nil
}
func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) {
key, _ := wgtypes.GenerateKey()
return &key, nil

View File

@ -4,13 +4,39 @@ package mesh
import (
"net"
"slices"
"github.com/tim-beatham/wgmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Route interface {
// GetDestination: returns the destination of the route
GetDestination() *net.IPNet
// GetHopCount: get the total hopcount of the prefix
GetHopCount() int
// GetPath: get a list of AS paths to get to the destination
GetPath() []string
}
type RouteStub struct {
Destination *net.IPNet
HopCount int
Path []string
}
func (r *RouteStub) GetDestination() *net.IPNet {
return r.Destination
}
func (r *RouteStub) GetHopCount() int {
return r.HopCount
}
func (r *RouteStub) GetPath() []string {
return r.Path
}
// MeshNode represents an implementation of a node in a mesh
type MeshNode interface {
// GetHostEndpoint: gets the gRPC endpoint of the node
@ -24,7 +50,7 @@ type MeshNode interface {
// GetTimestamp: get the UNIX time stamp of the ndoe
GetTimeStamp() int64
// GetRoutes: returns the routes that the nodes provides
GetRoutes() []string
GetRoutes() []Route
// GetIdentifier: returns the identifier of the node
GetIdentifier() string
// GetDescription: returns the description for this node
@ -34,46 +60,25 @@ type MeshNode interface {
GetAlias() string
// GetServices: returns a list of services offered by the node
GetServices() map[string]string
GetType() conf.NodeType
}
// NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool {
if node1.GetHostEndpoint() != node2.GetHostEndpoint() {
return false
}
key1, _ := node1.GetPublicKey()
key2, _ := node2.GetPublicKey()
node1Pub, _ := node1.GetPublicKey()
node2Pub, _ := node2.GetPublicKey()
return key1.String() == key2.String()
}
if node1Pub != node2Pub {
return false
}
func RouteEquals(route1, route2 Route) bool {
return route1.GetDestination().String() == route2.GetDestination().String() &&
route1.GetHopCount() == route2.GetHopCount()
}
if node1.GetWgEndpoint() != node2.GetWgEndpoint() {
return false
}
if node1.GetWgHost() != node2.GetWgHost() {
return false
}
if !slices.Equal(node1.GetRoutes(), node2.GetRoutes()) {
return false
}
if node1.GetIdentifier() != node2.GetIdentifier() {
return false
}
if node1.GetDescription() != node2.GetDescription() {
return false
}
if node1.GetAlias() != node2.GetAlias() {
return false
}
return true
func NodeID(node MeshNode) string {
key, _ := node.GetPublicKey()
return key.String()
}
type MeshSnapshot interface {
@ -109,7 +114,7 @@ type MeshProvider interface {
// UpdateTimeStamp: update the timestamp of the given node
UpdateTimeStamp(nodeId string) error
// AddRoutes: adds routes to the given node
AddRoutes(nodeId string, route ...string) error
AddRoutes(nodeId string, route ...Route) error
// DeleteRoutes: deletes the routes from the node
RemoveRoutes(nodeId string, route ...string) error
// GetSyncer: returns the automerge syncer for sync
@ -129,12 +134,20 @@ type MeshProvider interface {
// Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds
Prune(pruneAmount int) error
GetNodeIds() []string
// GetPeers: get a list of contactable peers
GetPeers() []string
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
GetRoutes(targetNode string) (map[string]Route, error)
}
// HostParameters contains the IDs of a node
type HostParameters struct {
HostEndpoint string
PrivateKey *wgtypes.Key
}
// GetPublicKey: gets the public key of the node
func (h *HostParameters) GetPublicKey() string {
return h.PrivateKey.PublicKey().String()
}
// MeshProviderFactoryParams parameters required to build a mesh provider

View File

@ -3,8 +3,10 @@ package query
import (
"encoding/json"
"fmt"
"strings"
"github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
@ -23,16 +25,23 @@ type QueryError struct {
msg string
}
type QueryRoute struct {
Destination string `json:"destination"`
HopCount int `json:"hopCount"`
Path string `json:"path"`
}
type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestmap"`
Timestamp int64 `json:"timestamp"`
Description string `json:"description"`
Routes []string `json:"routes"`
Routes []QueryRoute `json:"routes"`
Alias string `json:"alias"`
Services map[string]string `json:"services"`
Type conf.NodeType `json:"type"`
}
func (m *QueryError) Error() string {
@ -76,10 +85,17 @@ func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.WgHost = node.GetWgHost().String()
queryNode.Timestamp = node.GetTimeStamp()
queryNode.Routes = node.GetRoutes()
queryNode.Routes = lib.Map(node.GetRoutes(), func(r mesh.Route) QueryRoute {
return QueryRoute{
Destination: r.GetDestination().String(),
HopCount: r.GetHopCount(),
Path: strings.Join(r.GetPath(), ","),
}
})
queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices()
queryNode.Type = node.GetType()
return queryNode
}

View File

@ -10,6 +10,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc"
@ -20,7 +21,7 @@ type IpcHandler struct {
}
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.GetMeshManager().CreateMesh(args.IfName, args.WgPort)
meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort)
if err != nil {
return err
@ -72,7 +73,9 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
configuration := n.Server.GetConfiguration()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(configuration.Timeout))
defer cancel()
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
@ -83,7 +86,6 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId,
DevName: args.IfName,
WgPort: args.Port,
MeshBytes: meshReply.Mesh,
})
@ -118,19 +120,19 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
}
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.GetMeshManager().GetMesh(meshId)
theMesh := n.Server.GetMeshManager().GetMesh(meshId)
if mesh == nil {
if theMesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
meshSnapshot, err := mesh.GetMesh()
meshSnapshot, err := theMesh.GetMesh()
if err != nil {
return err
}
if mesh == nil {
if theMesh == nil {
return errors.New("mesh does not exist")
}
@ -150,10 +152,15 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) ctrlserver.MeshRoute {
return ctrlserver.MeshRoute{
Destination: r.GetDestination().String(),
Path: r.GetPath(),
}
}),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
}
nodes[i] = node
@ -164,18 +171,6 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
return nil
}
func (n *IpcHandler) EnableInterface(meshId string, reply *string) error {
err := n.Server.GetMeshManager().EnableInterface(meshId)
if err != nil {
*reply = err.Error()
return err
}
*reply = "up"
return nil
}
func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())

View File

@ -28,7 +28,3 @@ func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.
return &reply, nil
}
func (m *WgRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
return &rpc.JoinMeshReply{Success: true}, nil
}

View File

@ -1,22 +1,32 @@
package route
import (
"net"
"os/exec"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/lib"
"golang.org/x/sys/unix"
)
type RouteInstaller interface {
InstallRoutes(devName string, routes ...*net.IPNet) error
InstallRoutes(devName string, routes ...lib.Route) error
}
type RouteInstallerImpl struct{}
// InstallRoutes: installs a route into the routing table
func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet) error {
func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
}
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes {
err := r.installRoute(devName, route)
err := rtnl.AddRoute(devName, route)
if err != nil {
return err
@ -26,22 +36,6 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet)
return nil
}
// installRoute: installs a route into the linux table
func (r *RouteInstallerImpl) installRoute(devName string, route *net.IPNet) error {
// TODO: Find a library that automates this
cmd := exec.Command("/usr/bin/ip", "-6", "route", "add", route.String(), "dev", devName)
logging.Log.WriteInfof("%s %s", route.String(), devName)
if msg, err := cmd.CombinedOutput(); err != nil {
logging.Log.WriteErrorf(err.Error())
logging.Log.WriteErrorf(string(msg))
return err
}
return nil
}
func NewRouteInstaller() RouteInstaller {
return &RouteInstallerImpl{}
}

View File

@ -20,77 +20,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type MeshNode struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
PublicKey string `protobuf:"bytes,1,opt,name=publicKey,proto3" json:"publicKey,omitempty"`
WgEndpoint string `protobuf:"bytes,2,opt,name=wgEndpoint,proto3" json:"wgEndpoint,omitempty"`
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
WgHost string `protobuf:"bytes,4,opt,name=wgHost,proto3" json:"wgHost,omitempty"`
}
func (x *MeshNode) Reset() {
*x = MeshNode{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *MeshNode) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*MeshNode) ProtoMessage() {}
func (x *MeshNode) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use MeshNode.ProtoReflect.Descriptor instead.
func (*MeshNode) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *MeshNode) GetPublicKey() string {
if x != nil {
return x.PublicKey
}
return ""
}
func (x *MeshNode) GetWgEndpoint() string {
if x != nil {
return x.WgEndpoint
}
return ""
}
func (x *MeshNode) GetEndpoint() string {
if x != nil {
return x.Endpoint
}
return ""
}
func (x *MeshNode) GetWgHost() string {
if x != nil {
return x.WgHost
}
return ""
}
type GetMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@ -102,7 +31,7 @@ type GetMeshRequest struct {
func (x *GetMeshRequest) Reset() {
*x = GetMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -115,7 +44,7 @@ func (x *GetMeshRequest) String() string {
func (*GetMeshRequest) ProtoMessage() {}
func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -128,7 +57,7 @@ func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead.
func (*GetMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1}
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *GetMeshRequest) GetMeshId() string {
@ -149,7 +78,7 @@ type GetMeshReply struct {
func (x *GetMeshReply) Reset() {
*x = GetMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -162,7 +91,7 @@ func (x *GetMeshReply) String() string {
func (*GetMeshReply) ProtoMessage() {}
func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -175,7 +104,7 @@ func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead.
func (*GetMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{2}
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1}
}
func (x *GetMeshReply) GetMesh() []byte {
@ -185,145 +114,24 @@ func (x *GetMeshReply) GetMesh() []byte {
return nil
}
type JoinMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Changes []byte `protobuf:"bytes,1,opt,name=changes,proto3" json:"changes,omitempty"`
MeshId string `protobuf:"bytes,2,opt,name=meshId,proto3" json:"meshId,omitempty"`
}
func (x *JoinMeshRequest) Reset() {
*x = JoinMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshRequest) ProtoMessage() {}
func (x *JoinMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshRequest.ProtoReflect.Descriptor instead.
func (*JoinMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{3}
}
func (x *JoinMeshRequest) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
func (x *JoinMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
type JoinMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
}
func (x *JoinMeshReply) Reset() {
*x = JoinMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshReply) ProtoMessage() {}
func (x *JoinMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshReply.ProtoReflect.Descriptor instead.
func (*JoinMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{4}
}
func (x *JoinMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
var File_pkg_grpc_ctrlserver_ctrlserver_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc = []byte{
0x0a, 0x24, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x22, 0x7c, 0x0a, 0x08, 0x4d, 0x65, 0x73, 0x68, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x1c, 0x0a, 0x09,
0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x77, 0x67,
0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a,
0x77, 0x67, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74,
0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, 0x22, 0x28,
0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x22, 0x43, 0x0a, 0x0f,
0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12,
0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73,
0x68, 0x49, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49,
0x64, 0x22, 0x29, 0x0a, 0x0d, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20,
0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x32, 0x91, 0x01, 0x0a,
0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12,
0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e,
0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x40,
0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x19, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00,
0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65,
0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65,
0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x32, 0x4f,
0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70,
0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42,
0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x33,
}
var (
@ -338,21 +146,16 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP() []byte {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescData
}
var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_ctrlserver_proto_goTypes = []interface{}{
(*MeshNode)(nil), // 0: rpctypes.MeshNode
(*GetMeshRequest)(nil), // 1: rpctypes.GetMeshRequest
(*GetMeshReply)(nil), // 2: rpctypes.GetMeshReply
(*JoinMeshRequest)(nil), // 3: rpctypes.JoinMeshRequest
(*JoinMeshReply)(nil), // 4: rpctypes.JoinMeshReply
(*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest
(*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply
}
var file_pkg_grpc_ctrlserver_ctrlserver_proto_depIdxs = []int32{
1, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
3, // 1: rpctypes.MeshCtrlServer.JoinMesh:input_type -> rpctypes.JoinMeshRequest
2, // 2: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
4, // 3: rpctypes.MeshCtrlServer.JoinMesh:output_type -> rpctypes.JoinMeshReply
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
@ -365,18 +168,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*MeshNode); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshRequest); i {
case 0:
return &v.state
@ -388,7 +179,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshReply); i {
case 0:
return &v.state
@ -400,30 +191,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
@ -431,7 +198,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc,
NumEnums: 0,
NumMessages: 5,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},

View File

@ -23,7 +23,6 @@ const _ = grpc.SupportPackageIsVersion7
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type MeshCtrlServerClient interface {
GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error)
JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error)
}
type meshCtrlServerClient struct {
@ -43,21 +42,11 @@ func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest,
return out, nil
}
func (c *meshCtrlServerClient) JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error) {
out := new(JoinMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/JoinMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// MeshCtrlServerServer is the server API for MeshCtrlServer service.
// All implementations must embed UnimplementedMeshCtrlServerServer
// for forward compatibility
type MeshCtrlServerServer interface {
GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error)
JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error)
mustEmbedUnimplementedMeshCtrlServerServer()
}
@ -68,9 +57,6 @@ type UnimplementedMeshCtrlServerServer struct {
func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method JoinMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {}
// UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service.
@ -102,24 +88,6 @@ func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec f
return interceptor(ctx, in, info, handler)
}
func _MeshCtrlServer_JoinMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JoinMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.MeshCtrlServer/JoinMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, req.(*JoinMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@ -131,10 +99,6 @@ var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetMesh",
Handler: _MeshCtrlServer_GetMesh_Handler,
},
{
MethodName: "JoinMesh",
Handler: _MeshCtrlServer_JoinMesh_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver/ctrlserver.proto",

View File

@ -44,15 +44,20 @@ func (s *SyncerImpl) Sync(meshId string) error {
}
}
nodeNames := s.manager.GetMesh(meshId).GetNodeIds()
nodeNames := s.manager.GetMesh(meshId).GetPeers()
self, err := s.manager.GetSelf(meshId)
if err != nil {
return err
}
neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint())
selfPublickey, err := self.GetPublicKey()
if err != nil {
return err
}
neighbours := s.cluster.GetNeighbours(nodeNames, selfPublickey.String())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
for _, node := range randomSubset {
@ -63,7 +68,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint())
interCluster := s.cluster.GetInterCluster(nodeNames, selfPublickey.String())
randomSubset = append(randomSubset, interCluster)
}
@ -74,7 +79,14 @@ func (s *SyncerImpl) Sync(meshId string) error {
go func(i int) error {
defer waitGroup.Done()
err := s.requester.SyncMesh(meshId, randomSubset[i])
correspondingPeer := s.manager.GetNode(meshId, randomSubset[i])
if correspondingPeer == nil {
logging.Log.WriteErrorf("node %s does not exist", randomSubset[i])
}
err := s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint())
return err
}(index)
}

View File

@ -50,7 +50,6 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId,
DevName: ifName,
WgPort: port,
MeshBytes: reply.Mesh,
})

View File

@ -5,20 +5,6 @@ import (
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// SyncScheduler: Loops through all nodes in the mesh and runs a schedule to
// sync each event
type SyncScheduler interface {
Run() error
Stop() error
}
// SyncSchedulerImpl scheduler for sync scheduling
type SyncSchedulerImpl struct {
quit chan struct{}
server *ctrlserver.MeshCtrlServer
syncer Syncer
}
// Run implements SyncScheduler.
func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error {

View File

@ -64,11 +64,11 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
syncer = mesh.GetSyncer()
} else if meshId != in.MeshId {
return errors.New("Differing MeshIDs")
return errors.New("differing meshids")
}
if syncer == nil {
return errors.New("Syncer should not be nil")
return errors.New("syncer should not be nil")
}
msg, moreMessages := syncer.GenerateMessage()

View File

@ -1,4 +1,4 @@
package timestamp
package timer
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
@ -12,3 +12,11 @@ func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
}
func NewRouteScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
return ctrlServer.MeshManager.GetRouteManager().UpdateRoutes()
}
return *lib.NewTimer(timerFunc, 10)
}

View File

@ -2,8 +2,8 @@ package wg
type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(params *CreateInterfaceParams) error {
return nil
func (i *WgInterfaceManipulatorStub) CreateInterface(port int) (string, error) {
return "", nil
}
func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error {

View File

@ -1,5 +1,7 @@
package wg
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
type WgError struct {
msg string
}
@ -8,14 +10,9 @@ func (m *WgError) Error() string {
return m.msg
}
type CreateInterfaceParams struct {
IfName string
Port int
}
type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface
CreateInterface(params *CreateInterfaceParams) error
CreateInterface(port int, privateKey *wgtypes.Key) (string, error)
// AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface

View File

@ -1,6 +1,9 @@
package wg
import (
"crypto"
"crypto/rand"
"encoding/base64"
"fmt"
"github.com/tim-beatham/wgmesh/pkg/lib"
@ -13,40 +16,48 @@ type WgInterfaceManipulatorImpl struct {
client *wgctrl.Client
}
const hashLength = 6
// CreateInterface creates a WireGuard interface
func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error {
func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes.Key) (string, error) {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return fmt.Errorf("failed to access link: %w", err)
return "", fmt.Errorf("failed to access link: %w", err)
}
defer rtnl.Close()
err = rtnl.CreateLink(params.IfName)
randomBuf := make([]byte, 32)
_, err = rand.Read(randomBuf)
if err != nil {
return fmt.Errorf("failed to create link: %w", err)
return "", err
}
privateKey, err := wgtypes.GeneratePrivateKey()
md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
err = rtnl.CreateLink(md5Str)
if err != nil {
return fmt.Errorf("failed to create private key: %w", err)
return "", fmt.Errorf("failed to create link: %w", err)
}
var cfg wgtypes.Config = wgtypes.Config{
PrivateKey: &privateKey,
ListenPort: &params.Port,
PrivateKey: privKey,
ListenPort: &port,
}
err = m.client.ConfigureDevice(params.IfName, cfg)
err = m.client.ConfigureDevice(md5Str, cfg)
if err != nil {
return fmt.Errorf("failed to configure dev: %w", err)
m.RemoveInterface(md5Str)
return "", fmt.Errorf("failed to configure dev: %w", err)
}
logging.Log.WriteInfof("ip link set up dev %s type wireguard", params.IfName)
return nil
logging.Log.WriteInfof("ip link set up dev %s type wireguard", md5Str)
return md5Str, nil
}
// Add an address to the given interface

View File

@ -0,0 +1,92 @@
// Package to convert an IPV6 addres into 8 words
package what8words
import (
"bufio"
"bytes"
"fmt"
"net"
"os"
"strings"
)
type What8Words struct {
words []string
}
// Convert implements What8Words.
func (w *What8Words) Convert(ipStr string) (string, error) {
ip, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
return "", err
}
ip16 := ip.To16()
if ip16 == nil {
return "", fmt.Errorf("cannot convert ip to 16 representation")
}
representation := make([]string, 7)
for i := 2; i <= net.IPv6len-2; i += 2 {
word1 := w.words[ip16[i]]
word2 := w.words[ip16[i+1]]
representation[i/2-1] = fmt.Sprintf("%s-%s", word1, word2)
}
prefixSize, _ := ipNet.Mask.Size()
return strings.Join(representation[:prefixSize/16-1], "."), nil
}
// Convert implements What8Words.
func (w *What8Words) ConvertIdentifier(ipStr string) (string, error) {
ip, err := w.Convert(ipStr)
if err != nil {
return "", err
}
constituents := strings.Split(ip, ".")
return strings.Join(constituents[3:], "."), nil
}
func NewWhat8Words(pathToWords string) (*What8Words, error) {
words, err := ReadWords(pathToWords)
if err != nil {
return nil, err
}
return &What8Words{words: words}, nil
}
// ReadWords reads the what 8 words txt file
func ReadWords(wordFile string) ([]string, error) {
f, err := os.ReadFile(wordFile)
if err != nil {
return nil, err
}
words := make([]string, 257)
reader := bufio.NewScanner(bytes.NewReader(f))
counter := 0
for reader.Scan() && counter <= len(words) {
text := reader.Text()
words[counter] = text
counter++
if reader.Err() != nil {
return nil, reader.Err()
}
}
return words, nil
}