1
0
forked from extern/smegmesh

Compare commits

...

27 Commits

Author SHA1 Message Date
d8e156f13f 36-add-route-path-into-route-object
Added the route path into the route object so that we can
see what meshes packets are routed across.
2023-11-27 18:55:41 +00:00
3fca49a1c9 Merge pull request #35 from tim-beatham/34-fix-routing
34 fix routing
2023-11-27 16:05:06 +00:00
a2517a1e72 34-fix-routing
- Added mesh-to-mesh routing of hop count > 1
- If there is a tie-breaker with respect to the hop-count use consistent
hashing to determine the route to take based on the public key.
2023-11-27 15:56:30 +00:00
aef8b59f22 32-fix-routing
Flooding routes into other meshes a bit like BGP.
2023-11-25 03:15:58 +00:00
4030d17b41 Fixed routing issue 2023-11-24 17:49:06 +00:00
73db65660b Merge pull request #33 from tim-beatham/32-incorporate-dns
32-incorporate-dns
2023-11-24 15:05:40 +00:00
d1a74a7b95 32-incorporate-dns
Incorporated a DNS server. A DNS server can be run to resolve host
names.
2023-11-24 15:04:07 +00:00
f28ed8260d Merge pull request #30 from tim-beatham/29-only-ping-clients-who-have-updated-their-config
29-only-ping-clients-who-have-updated-their-config
2023-11-24 12:39:14 +00:00
2c406718df 29-only-ping-clients-who-have-updated-their-config
Only consider clients who have updated their config when synchronising
with peers. Consider a dead time where we don't have a handshake and
a prune time when we remove them from the WireGuard configuration.
2023-11-24 12:37:54 +00:00
11b003b549 Merge pull request #28 from tim-beatham/27-remove-client-grpc-endpoint
27-remove-client-grpc-endpoint
2023-11-24 12:08:42 +00:00
7be11dbaa3 27-remove-client-grpc-endpoint
Removed a client's grpc endpoint value. Client's aren't publicly
available so there is no need for a client's gRPC endpoint.
Also changed a node ID's to their public key. A node id's public
address is an issue for mobility of clients as their endpoint
is subject to change
2023-11-24 12:07:03 +00:00
e7ac8c5542 Only updating WireGuard config if node exists 2023-11-22 13:08:02 +00:00
09c64c4628 Fixed container file 2023-11-22 12:45:01 +00:00
2c4f18f52b Merge pull request #26 from tim-beatham/25-modify-code-to-use-public-api
25-modify-code-to-use-public-api
2023-11-22 10:42:48 +00:00
4c54022f63 25-modify-code-to-use-public-api
Modify the code to use a public IP address by default if none is
specified
2023-11-22 10:41:54 +00:00
bf0724f6e5 Merge pull request #24 from tim-beatham/24-keepalive-holepunch
24 keepalive holepunch
2023-11-21 21:28:16 +00:00
624bd6e921 24-keepalive
Persistent keep alive working
2023-11-21 21:26:31 +00:00
7b939e0468 24-keepalive-holepunch
Added the ability to hole punch NAT
2023-11-21 20:42:43 +00:00
6e201ebaf5 24-keepalive-holepunch
Nodes acting as peers and nodes acting as clients
2023-11-21 16:42:49 +00:00
06542da03c main
Fixed problems with timestamp not updating
2023-11-21 13:31:34 +00:00
0d63cd6624 main
Adding words.txt for what words
2023-11-20 18:12:58 +00:00
f13319cfc1 Merge pull request #22 from tim-beatham/21-phonetic-words-ipv6
21 phonetic words ipv6
2023-11-20 18:08:49 +00:00
95f4495b0b 21-phonetic-words-ipv6
Simple what 8 words implementation
2023-11-20 18:07:52 +00:00
330fa74ef4 IPv6 What 8 Words
what 8 words for ipv6 started
2023-11-20 15:22:32 +00:00
3e5b57e41f Merge pull request #20 from tim-beatham/19-hash-wg-interface
Hashing the WireGuard interface
2023-11-20 13:04:19 +00:00
b179cd3cf4 Hashing the WireGuard interface
Hashing the interface and using ephmeral ports so that the admin doesn't
choose an interface and port combination. An administrator can alteranatively
decide to provide port but this isn't critical.
2023-11-20 13:03:42 +00:00
8f211aa116 Merge pull request #18 from tim-beatham/26-performance-testing
Stubbing out WireGuard components
2023-11-20 11:29:37 +00:00
45 changed files with 1421 additions and 742 deletions

View File

@ -8,4 +8,5 @@ RUN apt-get update && apt-get install -y \
tmux \ tmux \
vim vim
WORKDIR /wgmesh WORKDIR /wgmesh
RUN go mod tidy
RUN go build -o /usr/local/bin ./... RUN go build -o /usr/local/bin ./...

View File

@ -7,11 +7,13 @@ import (
) )
func main() { func main() {
apiServer, err := api.NewSmegServer() apiServer, err := api.NewSmegServer(api.ApiServerConf{
WordsFile: "./cmd/api/words.txt",
})
if err != nil { if err != nil {
log.Fatal(err.Error()) log.Fatal(err.Error())
} }
apiServer.Run(":40000") apiServer.Run(":8080")
} }

257
cmd/api/words.txt Normal file
View File

@ -0,0 +1,257 @@
be
to
of
it
in
we
do
he
on
go
at
if
or
up
by
hi
the
and
you
not
for
but
say
get
she
one
all
can
out
who
now
see
way
how
lot
yes
use
any
day
try
put
let
why
new
off
big
too
ask
man
bit
end
may
own
run
pay
job
old
kid
bad
few
ago
far
buy
set
guy
car
sit
war
win
yet
top
law
cut
low
die
eat
age
hit
air
add
boy
act
tax
oil
eye
son
key
fun
dad
dog
arm
fly
box
gas
lie
hot
gun
per
art
red
fit
bed
fan
mix
mom
sex
bus
fix
bar
lay
ice
bet
bag
due
aid
tie
leg
ban
odd
cup
dry
cry
rid
pop
sir
cat
map
sad
sea
aim
sun
fat
row
egg
tea
god
wed
tip
ear
hat
net
ill
dig
fee
mad
gap
nor
bid
era
toy
sky
bin
owe
wet
tap
pro
ski
cow
pen
van
web
pot
sum
cap
log
pub
pig
joy
raw
rat
via
lip
two
six
ten
lab
ton
mid
bat
hip
gut
sin
non
rub
sub
par
pre
ray
cue
dye
fin
ion
neo
hey
wow
mum
bye
aye
jet
sue
pet
flu
cop
ooh
rip
spy
pie
bug
gum
wan
rap
nut
beg
pin
pit
jam
tag
fax
vet
fry
pad
lad
mud
bay
con
pan
gee
toe
dip
shy
gym
zoo
fox
bow
tin
hop
wee
kit
opt
vow
sew
cab
bee
rob
rig
yep
ego
rib
nod
hug
lap
ash
hum
dam
bum
yen
jar

18
cmd/dns/main.go Normal file
View File

@ -0,0 +1,18 @@
package main
import (
"log"
smegdns "github.com/tim-beatham/wgmesh/pkg/dns"
)
func main() {
server, err := smegdns.NewDns(53)
if err != nil {
log.Fatal(err.Error())
}
defer server.Close()
server.Listen()
}

View File

@ -8,7 +8,9 @@ import (
"time" "time"
"github.com/akamensky/argparse" "github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/wgmesh/pkg/log"
) )
@ -16,7 +18,6 @@ const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct { type CreateMeshParams struct {
Client *ipcRpc.Client Client *ipcRpc.Client
IfName string
WgPort int WgPort int
Endpoint string Endpoint string
} }
@ -24,7 +25,6 @@ type CreateMeshParams struct {
func createMesh(args *CreateMeshParams) string { func createMesh(args *CreateMeshParams) string {
var reply string var reply string
newMeshParams := ipc.NewMeshArgs{ newMeshParams := ipc.NewMeshArgs{
IfName: args.IfName,
WgPort: args.WgPort, WgPort: args.WgPort,
Endpoint: args.Endpoint, Endpoint: args.Endpoint,
} }
@ -68,7 +68,6 @@ func joinMesh(params *JoinMeshParams) string {
args := ipc.JoinMeshArgs{ args := ipc.JoinMeshArgs{
MeshId: params.MeshId, MeshId: params.MeshId,
IpAdress: params.IpAddress, IpAdress: params.IpAddress,
IfName: params.IfName,
Port: params.WgPort, Port: params.WgPort,
} }
@ -96,9 +95,13 @@ func getMesh(client *ipcRpc.Client, meshId string) {
fmt.Println("Control Endpoint: " + node.HostEndpoint) fmt.Println("Control Endpoint: " + node.HostEndpoint)
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint) fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
fmt.Println("Wg IP: " + node.WgHost) fmt.Println("Wg IP: " + node.WgHost)
fmt.Println(fmt.Sprintf("Timestamp: %s", time.Unix(node.Timestamp, 0).String())) fmt.Printf("Timestamp: %s", time.Unix(node.Timestamp, 0).String())
advertiseRoutes := strings.Join(node.Routes, ",") mapFunc := func(r ctrlserver.MeshRoute) string {
return r.Destination
}
advertiseRoutes := strings.Join(lib.Map(node.Routes, mapFunc), ",")
fmt.Printf("Routes: %s\n", advertiseRoutes) fmt.Printf("Routes: %s\n", advertiseRoutes)
fmt.Println("---") fmt.Println("---")
@ -240,7 +243,6 @@ func main() {
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh") newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to") listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network") joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network")
// getMeshCmd := parser.NewCommand("get-mesh", "Get a mesh network")
enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface") enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format") getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network") leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
@ -251,14 +253,12 @@ func main() {
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements") deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh") getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{}) var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true}) var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{}) var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -298,7 +298,6 @@ func main() {
if newMeshCmd.Happened() { if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{ fmt.Println(createMesh(&CreateMeshParams{
Client: client, Client: client,
IfName: *newMeshIfName,
WgPort: *newMeshPort, WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint, Endpoint: *newMeshEndpoint,
})) }))
@ -311,7 +310,6 @@ func main() {
if joinMeshCmd.Happened() { if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{ fmt.Println(joinMesh(&JoinMeshParams{
Client: client, Client: client,
IfName: *joinMeshIfName,
WgPort: *joinMeshPort, WgPort: *joinMeshPort,
IpAddress: *joinMeshIpAddress, IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId, MeshId: *joinMeshId,

View File

@ -13,7 +13,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/robin" "github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync" "github.com/tim-beatham/wgmesh/pkg/sync"
"github.com/tim-beatham/wgmesh/pkg/timestamp" timer "github.com/tim-beatham/wgmesh/pkg/timers"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
@ -57,8 +57,9 @@ func main() {
syncProvider.Server = ctrlServer syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer) syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester) syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester)
timestampScheduler := timestamp.NewTimestampScheduler(ctrlServer) timestampScheduler := timer.NewTimestampScheduler(ctrlServer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf) pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
routeScheduler := timer.NewRouteScheduler(ctrlServer)
robinIpcParams := robin.RobinIpcParams{ robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer, CtrlServer: ctrlServer,
@ -78,6 +79,7 @@ func main() {
go syncScheduler.Run() go syncScheduler.Run()
go timestampScheduler.Run() go timestampScheduler.Run()
go pruneScheduler.Run() go pruneScheduler.Run()
go routeScheduler.Run()
closeResources := func() { closeResources := func() {
logging.Log.WriteInfof("Closing resources") logging.Log.WriteInfof("Closing resources")

View File

@ -10,6 +10,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/what8words"
) )
const SockAddr = "/tmp/wgmesh_ipc.sock" const SockAddr = "/tmp/wgmesh_ipc.sock"
@ -22,11 +23,36 @@ type ApiServer interface {
type SmegServer struct { type SmegServer struct {
router *gin.Engine router *gin.Engine
client *ipcRpc.Client client *ipcRpc.Client
words *what8words.What8Words
} }
func meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode { func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
routes := make([]Route, len(meshNode.Routes))
for index, route := range meshNode.Routes {
if route.Path == nil {
route.Path = make([]string, 0)
}
routes[index] = Route{
Prefix: route.Destination,
Path: route.Path,
}
}
return routes
}
func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil { if meshNode.Routes == nil {
meshNode.Routes = make([]string, 0) meshNode.Routes = make([]ctrlserver.MeshRoute, 0)
}
alias := meshNode.Alias
if alias == "" {
alias, _ = s.words.ConvertIdentifier(meshNode.WgHost)
} }
return &SmegNode{ return &SmegNode{
@ -35,20 +61,20 @@ func meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
Endpoint: meshNode.HostEndpoint, Endpoint: meshNode.HostEndpoint,
Timestamp: int(meshNode.Timestamp), Timestamp: int(meshNode.Timestamp),
Description: meshNode.Description, Description: meshNode.Description,
Routes: meshNode.Routes, Routes: s.routeToApiRoute(meshNode),
PublicKey: meshNode.PublicKey, PublicKey: meshNode.PublicKey,
Alias: meshNode.Alias, Alias: alias,
Services: meshNode.Services, Services: meshNode.Services,
} }
} }
func meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh { func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
var smegMesh SmegMesh var smegMesh SmegMesh
smegMesh.MeshId = meshId smegMesh.MeshId = meshId
smegMesh.Nodes = make(map[string]SmegNode) smegMesh.Nodes = make(map[string]SmegNode)
for _, node := range nodes { for _, node := range nodes {
smegMesh.Nodes[node.WgHost] = *meshNodeToAPIMeshNode(node) smegMesh.Nodes[node.WgHost] = *s.meshNodeToAPIMeshNode(node)
} }
return smegMesh return smegMesh
@ -62,11 +88,11 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
c.JSON(http.StatusBadRequest, &gin.H{ c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(), "error": err.Error(),
}) })
return return
} }
ipcRequest := ipc.NewMeshArgs{ ipcRequest := ipc.NewMeshArgs{
IfName: createMesh.IfName,
WgPort: createMesh.WgPort, WgPort: createMesh.WgPort,
} }
@ -100,7 +126,6 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
ipcRequest := ipc.JoinMeshArgs{ ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId, MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap, IpAdress: joinMesh.Bootstrap,
IfName: joinMesh.IfName,
Port: joinMesh.WgPort, Port: joinMesh.WgPort,
} }
@ -139,7 +164,7 @@ func (s *SmegServer) GetMesh(c *gin.Context) {
return return
} }
mesh := meshToAPIMesh(meshidParam, getMeshReply.Nodes) mesh := s.meshToAPIMesh(meshidParam, getMeshReply.Nodes)
c.JSON(http.StatusOK, mesh) c.JSON(http.StatusOK, mesh)
} }
@ -168,7 +193,7 @@ func (s *SmegServer) GetMeshes(c *gin.Context) {
return return
} }
meshes = append(meshes, meshToAPIMesh(mesh, getMeshReply.Nodes)) meshes = append(meshes, s.meshToAPIMesh(mesh, getMeshReply.Nodes))
} }
c.JSON(http.StatusOK, meshes) c.JSON(http.StatusOK, meshes)
@ -179,13 +204,19 @@ func (s *SmegServer) Run(addr string) error {
return s.router.Run(addr) return s.router.Run(addr)
} }
func NewSmegServer() (ApiServer, error) { func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr) client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
words, err := what8words.NewWhat8Words(conf.WordsFile)
if err != nil {
return nil, err
}
router := gin.Default() router := gin.Default()
router.Use(gin.LoggerWithConfig(gin.LoggerConfig{ router.Use(gin.LoggerWithConfig(gin.LoggerConfig{
@ -195,6 +226,7 @@ func NewSmegServer() (ApiServer, error) {
smegServer := &SmegServer{ smegServer := &SmegServer{
router: router, router: router,
client: client, client: client,
words: words,
} }
router.GET("/meshes", smegServer.GetMeshes) router.GET("/meshes", smegServer.GetMeshes)

View File

@ -1,5 +1,10 @@
package api package api
type Route struct {
Prefix string `json:"prefix"`
Path []string `json:"path"`
}
type SmegNode struct { type SmegNode struct {
Alias string `json:"alias"` Alias string `json:"alias"`
WgHost string `json:"wgHost"` WgHost string `json:"wgHost"`
@ -8,7 +13,7 @@ type SmegNode struct {
Timestamp int `json:"timestamp"` Timestamp int `json:"timestamp"`
Description string `json:"description"` Description string `json:"description"`
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
Routes []string `json:"routes"` Routes []Route `json:"routes"`
Services map[string]string `json:"services"` Services map[string]string `json:"services"`
} }
@ -18,13 +23,15 @@ type SmegMesh struct {
} }
type CreateMeshRequest struct { type CreateMeshRequest struct {
IfName string `json:"ifName" binding:"required"` WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
} }
type JoinMeshRequest struct { type JoinMeshRequest struct {
IfName string `json:"ifName" binding:"required"` WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
Bootstrap string `json:"bootstrap" binding:"required"` Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"` MeshId string `json:"meshid" binding:"required"`
} }
type ApiServerConf struct {
WordsFile string
}

View File

@ -35,15 +35,55 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
panic("node must be of type *MeshNodeCrdt") panic("node must be of type *MeshNodeCrdt")
} }
crdt.Routes = make(map[string]interface{}) crdt.Routes = make(map[string]Route)
crdt.Services = make(map[string]string) crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix() crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt) c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
} }
func (c *CrdtMeshManager) GetNodeIds() []string { func (c *CrdtMeshManager) isPeer(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
nodeType, err := node.Map().Get("type")
if err != nil || nodeType.Kind() != automerge.KindStr {
return false
}
return nodeType.Str() == string(conf.PEER_ROLE)
}
// isAlive: checks that the node's configuration has been updated
// since the rquired keep alive time
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
timestamp, err := node.Map().Get("timestamp")
if err != nil || timestamp.Kind() != automerge.KindInt64 {
return false
}
keepAliveTime := timestamp.Int64()
return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
}
func (c *CrdtMeshManager) GetPeers() []string {
keys, _ := c.doc.Path("nodes").Map().Keys() keys, _ := c.doc.Path("nodes").Map().Keys()
keys = lib.Filter(keys, func(publicKey string) bool {
return c.isPeer(publicKey) && c.isAlive(publicKey)
})
return keys return keys
} }
@ -55,7 +95,7 @@ func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return nil, err return nil, err
} }
if c.cache == nil || len(changes) > 3 { if c.cache == nil || len(changes) > 0 {
c.lastCacheHash = c.LastHash c.lastCacheHash = c.LastHash
cache, err := automerge.As[*MeshCrdt](c.doc.Root()) cache, err := automerge.As[*MeshCrdt](c.doc.Root())
@ -113,7 +153,7 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
// NodeExists: returns true if the node exists. Returns false // NodeExists: returns true if the node exists. Returns false
func (m *CrdtMeshManager) NodeExists(key string) bool { func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key) node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err != nil return node.Kind() == automerge.KindMap && err == nil
} }
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) { func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
@ -279,7 +319,7 @@ func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
} }
// AddRoutes: adds routes to the specific nodeId // AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error { func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
logging.Log.WriteInfof("Adding route to %s", nodeId) logging.Log.WriteInfof("Adding route to %s", nodeId)
@ -298,7 +338,10 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
} }
for _, route := range routes { for _, route := range routes {
err = routeMap.Map().Set(route, struct{}{}) err = routeMap.Map().Set(route.GetDestination().String(), Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
})
if err != nil { if err != nil {
return err return err
@ -307,6 +350,67 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
return nil return nil
} }
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return nil, err
}
if nodeVal.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node does not exist")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return nil, err
}
if routeMap.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node %s is not a map", nodeId)
}
routes, err := automerge.As[map[string]Route](routeMap)
return lib.MapValues(routes), err
}
func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
if err != nil {
return nil, err
}
routes := make(map[string]mesh.Route)
for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route
}
for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node)
if err != nil {
return nil, err
}
for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()]
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(),
Path: append(route.Path, m.GetMeshId()),
}
}
}
}
return routes, nil
}
// DeleteRoutes deletes the specified routes // DeleteRoutes deletes the specified routes
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error { func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -419,8 +523,13 @@ func (m *MeshNodeCrdt) GetTimeStamp() int64 {
return m.Timestamp return m.Timestamp
} }
func (m *MeshNodeCrdt) GetRoutes() []string { func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
return lib.MapKeys(m.Routes) return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route {
return &Route{
Destination: r.Destination,
Path: r.Path,
}
})
} }
func (m *MeshNodeCrdt) GetDescription() string { func (m *MeshNodeCrdt) GetDescription() string {
@ -450,6 +559,12 @@ func (m *MeshNodeCrdt) GetServices() map[string]string {
return services return services
} }
// GetType refers to the type of the node. Peer means that the node is globally accessible
// Client means the node is only accessible through another peer
func (n *MeshNodeCrdt) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode) nodes := make(map[string]mesh.MeshNode)
@ -464,8 +579,22 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
Description: node.Description, Description: node.Description,
Alias: node.Alias, Alias: node.Alias,
Services: node.GetServices(), Services: node.GetServices(),
Type: node.Type,
} }
} }
return nodes return nodes
} }
func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet
}
func (r *Route) GetHopCount() int {
return len(r.Path)
}
func (r *Route) GetPath() []string {
return r.Path
}

View File

@ -28,16 +28,23 @@ type MeshNodeFactory struct {
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode { func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params) hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
if f.Config.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
}
return &MeshNodeCrdt{ return &MeshNodeCrdt{
HostEndpoint: fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort), HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(), PublicKey: params.PublicKey.String(),
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort), WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort),
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()), WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
// Always set the routes as empty. // Always set the routes as empty.
// Routes handled by external component // Routes handled by external component
Routes: map[string]interface{}{}, Routes: make(map[string]Route),
Description: "", Description: "",
Alias: "", Alias: "",
Type: string(f.Config.Role),
} }
} }
@ -50,7 +57,19 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
} else if len(f.Config.Endpoint) != 0 { } else if len(f.Config.Endpoint) != 0 {
hostName = f.Config.Endpoint hostName = f.Config.Endpoint
} else { } else {
hostName = lib.GetOutboundIP().String() ipFunc := lib.GetPublicIP
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}
ip, err := ipFunc()
if err != nil {
return ""
}
hostName = ip.String()
} }
return hostName return hostName

View File

@ -1,16 +1,23 @@
package crdt package crdt
// Route: Represents a CRDT of the given route
type Route struct {
Destination string `automerge:"destination"`
Path []string `automerge:"path"`
}
// MeshNodeCrdt: Represents a CRDT for a mesh nodes // MeshNodeCrdt: Represents a CRDT for a mesh nodes
type MeshNodeCrdt struct { type MeshNodeCrdt struct {
HostEndpoint string `automerge:"hostEndpoint"` HostEndpoint string `automerge:"hostEndpoint"`
WgEndpoint string `automerge:"wgEndpoint"` WgEndpoint string `automerge:"wgEndpoint"`
PublicKey string `automerge:"publicKey"` PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"` WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"` Timestamp int64 `automerge:"timestamp"`
Routes map[string]interface{} `automerge:"routes"` Routes map[string]Route `automerge:"routes"`
Alias string `automerge:"alias"` Alias string `automerge:"alias"`
Description string `automerge:"description"` Description string `automerge:"description"`
Services map[string]string `automerge:"services"` Services map[string]string `automerge:"services"`
Type string `automerge:"type"`
} }
// MeshCrdt: Represents the mesh network as a whole // MeshCrdt: Represents the mesh network as a whole

View File

@ -16,6 +16,20 @@ func (m *WgMeshConfigurationError) Error() string {
return m.msg return m.msg
} }
type NodeType string
const (
PEER_ROLE NodeType = "peer"
CLIENT_ROLE NodeType = "client"
)
type IPDiscovery string
const (
PUBLIC_IP_DISCOVERY = "public"
DNS_IP_DISCOVERY = "dns"
)
type WgMeshConfiguration struct { type WgMeshConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS // CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath"` CertificatePath string `yaml:"certificatePath"`
@ -28,6 +42,9 @@ type WgMeshConfiguration struct {
SkipCertVerification bool `yaml:"skipCertVerification"` SkipCertVerification bool `yaml:"skipCertVerification"`
// Port to run the GrpcServer on // Port to run the GrpcServer on
GrpcPort string `yaml:"gRPCPort"` GrpcPort string `yaml:"gRPCPort"`
// IPDIscovery: how to discover your IP if not specified. Use DNS server 8.8.8.8 or
// use public IP discovery library
IPDiscovery IPDiscovery `yaml:"ipDiscovery"`
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes // AdvertiseRoutes advertises other meshes if the node is in multiple meshes
AdvertiseRoutes bool `yaml:"advertiseRoutes"` AdvertiseRoutes bool `yaml:"advertiseRoutes"`
// Endpoint is the IP in which this computer is publicly reachable. // Endpoint is the IP in which this computer is publicly reachable.
@ -47,12 +64,21 @@ type WgMeshConfiguration struct {
KeepAliveTime int `yaml:"keepAliveTime"` KeepAliveTime int `yaml:"keepAliveTime"`
// Timeout number of seconds before we consider the node as dead // Timeout number of seconds before we consider the node as dead
Timeout int `yaml:"timeout"` Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we consider the 'node' as dead // PruneTime number of seconds before we remove nodes that are likely to be dead
PruneTime int `yaml:"pruneTime"` PruneTime int `yaml:"pruneTime"`
// DeadTime: number of seconds before we consider the node as dead and stop considering it
// when picking a random peer
DeadTime int `yaml:"deadTime"`
// Profile whether or not to include a http server that profiles the code // Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"` Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types // StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"` StubWg bool `yaml:"stubWg"`
// Role specifies whether or not the user is globally accessible.
// If the user is globaly accessible they specify themselves as a client.
Role NodeType `yaml:"role"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg int `yaml:"keepAliveWg"`
} }
func ValidateConfiguration(c *WgMeshConfiguration) error { func ValidateConfiguration(c *WgMeshConfiguration) error {
@ -122,9 +148,15 @@ func ValidateConfiguration(c *WgMeshConfiguration) error {
} }
} }
if c.PruneTime <= 1 { if c.PruneTime < 1 {
return &WgMeshConfigurationError{ return &WgMeshConfigurationError{
msg: "Prune time cannot be <= 1", msg: "Prune time cannot be < 1",
}
}
if c.DeadTime < 1 {
return &WgMeshConfigurationError{
msg: "Dead time cannot be < 1",
} }
} }
@ -134,6 +166,14 @@ func ValidateConfiguration(c *WgMeshConfiguration) error {
} }
} }
if c.Role == "" {
c.Role = PEER_ROLE
}
if c.IPDiscovery == "" {
c.IPDiscovery = PUBLIC_IP_DISCOVERY
}
return nil return nil
} }

View File

@ -31,11 +31,11 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
nodeFactory := crdt.MeshNodeFactory{ nodeFactory := crdt.MeshNodeFactory{
Config: *params.Conf, Config: *params.Conf,
} }
idGenerator := &lib.UUIDGenerator{} idGenerator := &lib.IDNameGenerator{}
ipAllocator := &ip.ULABuilder{} ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client) interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
configApplyer := mesh.NewWgMeshConfigApplyer() configApplyer := mesh.NewWgMeshConfigApplyer(params.Conf)
meshManagerParams := &mesh.NewMeshManagerParams{ meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf, Conf: *params.Conf,

View File

@ -9,6 +9,11 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
type MeshRoute struct {
Destination string
Path []string
}
// Represents a WireGuard MeshNode // Represents a WireGuard MeshNode
type MeshNode struct { type MeshNode struct {
HostEndpoint string HostEndpoint string
@ -16,7 +21,7 @@ type MeshNode struct {
PublicKey string PublicKey string
WgHost string WgHost string
Timestamp int64 Timestamp int64
Routes []string Routes []MeshRoute
Description string Description string
Alias string Alias string
Services map[string]string Services map[string]string

114
pkg/dns/dns.go Normal file
View File

@ -0,0 +1,114 @@
package smegdns
import (
"encoding/json"
"fmt"
"net"
"net/rpc"
"github.com/miekg/dns"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/query"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.`
type DNSHandler struct {
client *rpc.Client
server *dns.Server
}
// queryMesh: queries the mesh network for the given meshId and node
// with alias
func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
var reply string
err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{
MeshId: meshId,
Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias),
}, &reply)
if err != nil {
return nil
}
var node *query.QueryNode
err = json.Unmarshal([]byte(reply), &node)
if err != nil || node == nil {
return nil
}
ip, _, _ := net.ParseCIDR(node.WgHost)
return ip
}
func (d *DNSHandler) handleQuery(m *dns.Msg) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeAAAA:
logging.Log.WriteInfof("Query for %s", q.Name)
groups := lib.MatchCaptureGroup(MeshRegularExpression, q.Name)
if len(groups) == 0 {
continue
}
ip := d.queryMesh(groups["meshId"], groups["alias"])
rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip))
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
}
}
func (h *DNSHandler) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
msg := new(dns.Msg)
msg.SetReply(r)
msg.Authoritative = true
switch r.Opcode {
case dns.OpcodeQuery:
h.handleQuery(msg)
}
w.WriteMsg(msg)
}
func (h *DNSHandler) Listen() error {
return h.server.ListenAndServe()
}
func (h *DNSHandler) Close() error {
return h.server.Shutdown()
}
func NewDns(udpPort int) (*DNSHandler, error) {
client, err := rpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
dnsHander := DNSHandler{
client: client,
}
dns.HandleFunc("smeg.", dnsHander.handleDnsRequest)
dnsHander.server = &dns.Server{Addr: fmt.Sprintf(":%d", udpPort), Net: "udp"}
return &dnsHander, nil
}

View File

@ -4,13 +4,13 @@ package rpctypes;
option go_package = "pkg/rpc"; option go_package = "pkg/rpc";
service MeshCtrlServer { service MeshCtrlServer {
rpc JoinMesh(JoinMeshRequest) returns (JoinMeshReply) {} rpc GetMesh(GetMeshRequest) returns (GetMeshReply) {}
} }
message JoinMeshRequest { message GetMeshRequest {
string meshId = 2; string meshId = 1;
} }
message JoinMeshReply { message GetMeshReply {
bool success = 1; bytes mesh = 1;
} }

View File

@ -11,8 +11,6 @@ import (
) )
type NewMeshArgs struct { type NewMeshArgs struct {
// IfName is the interface that the mesh instance will run on
IfName string
// WgPort is the WireGuard port to expose // WgPort is the WireGuard port to expose
WgPort int WgPort int
// Endpoint is the routable alias of the machine. Can be an IP // Endpoint is the routable alias of the machine. Can be an IP
@ -25,13 +23,14 @@ type JoinMeshArgs struct {
MeshId string MeshId string
// IpAddress is a routable IP in another mesh // IpAddress is a routable IP in another mesh
IpAdress string IpAdress string
// IfName is the interface name of the mesh
IfName string
// Port is the WireGuard port to expose // Port is the WireGuard port to expose
Port int Port int
// Endpoint is the routable address of this machine. If not provided // Endpoint is the routable address of this machine. If not provided
// defaults to the default address // defaults to the default address
Endpoint string Endpoint string
// Client specifies whether we should join as a client of the peer
// we are connecting to
Client bool
} }
type PutServiceArgs struct { type PutServiceArgs struct {
@ -63,7 +62,6 @@ type MeshIpc interface {
JoinMesh(args JoinMeshArgs, reply *string) error JoinMesh(args JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error GetMesh(meshId string, reply *GetMeshReply) error
EnableInterface(meshId string, reply *string) error
GetDOT(meshId string, reply *string) error GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error PutDescription(description string, reply *string) error

View File

@ -66,3 +66,13 @@ func Filter[V any](list []V, f filterFunc[V]) []V {
return newList return newList
} }
func Contains[V any](list []V, proposition func(V) bool) bool {
for _, elem := range list {
if proposition(elem) {
return true
}
}
return false
}

46
pkg/lib/hashing.go Normal file
View File

@ -0,0 +1,46 @@
package lib
import (
"hash/fnv"
"sort"
)
type consistentHashRecord[V any] struct {
record V
value int
}
func HashString(value string) int {
f := fnv.New32a()
f.Write([]byte(value))
return int(f.Sum32())
}
// ConsistentHash implementation. Traverse the values until we find a key
// less than ours.
func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int, keyFunc func(K) int) V {
if len(values) == 0 {
panic("values is empty")
}
vs := Map(values, func(v V) consistentHashRecord[V] {
return consistentHashRecord[V]{
v,
bucketFunc(v),
}
})
sort.SliceStable(vs, func(i, j int) bool {
return vs[i].value < vs[j].value
})
ourKey := keyFunc(client)
for _, record := range vs {
if ourKey < record.value {
return record.record
}
}
return vs[0].record
}

View File

@ -1,6 +1,9 @@
package lib package lib
import "github.com/google/uuid" import (
"github.com/anandvarma/namegen"
"github.com/google/uuid"
)
// IdGenerator generates unique ids // IdGenerator generates unique ids
type IdGenerator interface { type IdGenerator interface {
@ -15,3 +18,11 @@ func (g *UUIDGenerator) GetId() (string, error) {
id := uuid.New() id := uuid.New()
return id.String(), nil return id.String(), nil
} }
type IDNameGenerator struct {
}
func (i *IDNameGenerator) GetId() (string, error) {
name_schema := namegen.New()
return name_schema.Get(), nil
}

View File

@ -1,17 +1,61 @@
package lib package lib
import ( import (
"encoding/json"
"io"
"log" "log"
"net" "net"
"net/http"
) )
// GetOutboundIP: gets the oubound IP of this packet // GetOutboundIP: gets the oubound IP of this packet
func GetOutboundIP() net.IP { func GetOutboundIP() (net.IP, error) {
conn, err := net.Dial("udp", "8.8.8.8:80") conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer conn.Close() defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr) localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP return localAddr.IP, nil
}
const IP_SERVICE = "https://api.ipify.org?format=json"
type IpResponse struct {
Ip string `json:"ip"`
}
func (i *IpResponse) GetIP() net.IP {
return net.ParseIP(i.Ip)
}
// GetPublicIP: get the nodes public IP address. For when a node is behind NAT
func GetPublicIP() (net.IP, error) {
req, err := http.NewRequest(http.MethodGet, IP_SERVICE, nil)
if err != nil {
return nil, err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
resBody, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
var jsonResponse IpResponse
err = json.Unmarshal([]byte(resBody), &jsonResponse)
if err != nil {
return nil, err
}
return jsonResponse.GetIP(), nil
} }

19
pkg/lib/regex.go Normal file
View File

@ -0,0 +1,19 @@
package lib
import "regexp"
func MatchCaptureGroup(pattern, payload string) map[string]string {
patterns := make(map[string]string)
expr := regexp.MustCompile(pattern)
match := expr.FindStringSubmatch(payload)
for i, name := range expr.SubexpNames() {
if i != 0 && name != "" {
patterns[name] = match[i]
}
}
return patterns
}

View File

@ -201,7 +201,7 @@ func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to delete route %w", err) return fmt.Errorf("failed to delete route %s", dst.IP.String())
} }
return nil return nil
@ -219,22 +219,15 @@ func (r1 Route) equal(r2 Route) bool {
// DeleteRoutes deletes all routes not in exclude // DeleteRoutes deletes all routes not in exclude
func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error { func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error {
routes := make([]rtnetlink.RouteMessage, 0) routes, err := c.listRoutes(ifName, family)
if len(exclude) != 0 { if err != nil {
lRoutes, err := c.listRoutes(ifName, family, exclude[0].Gateway) return err
if err != nil {
return err
}
routes = lRoutes
} }
ifRoutes := make([]Route, 0) ifRoutes := make([]Route, 0)
for _, rtRoute := range routes { for _, rtRoute := range routes {
logging.Log.WriteInfof("Routes: %s", rtRoute.Attributes.Dst.String())
maskSize := 128 maskSize := 128
if family == unix.AF_INET { if family == unix.AF_INET {
@ -262,7 +255,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
toDelete := Filter(ifRoutes, shouldExclude) toDelete := Filter(ifRoutes, shouldExclude)
for _, route := range toDelete { for _, route := range toDelete {
logging.Log.WriteInfof("Deleting route %s", route.Destination.String()) logging.Log.WriteInfof("Deleting route: %s", route.Gateway.String())
err := c.DeleteRoute(ifName, route) err := c.DeleteRoute(ifName, route)
if err != nil { if err != nil {
@ -274,7 +267,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
} }
// listRoutes lists all routes on the interface // listRoutes lists all routes on the interface
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP) ([]rtnetlink.RouteMessage, error) { func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.RouteMessage, error) {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
if err != nil { if err != nil {
@ -288,7 +281,7 @@ func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP
} }
filterFunc := func(r rtnetlink.RouteMessage) bool { filterFunc := func(r rtnetlink.RouteMessage) bool {
return r.Attributes.Gateway.Equal(gateway) && r.Attributes.OutIface == uint32(iface.Index) return r.Attributes.Gateway != nil && r.Attributes.OutIface == uint32(iface.Index)
} }
routes = Filter(routes, filterFunc) routes = Filter(routes, filterFunc)

View File

@ -3,7 +3,13 @@ package mesh
import ( import (
"fmt" "fmt"
"net" "net"
"slices"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -16,10 +22,24 @@ type MeshConfigApplyer interface {
// WgMeshConfigApplyer applies WireGuard configuration // WgMeshConfigApplyer applies WireGuard configuration
type WgMeshConfigApplyer struct { type WgMeshConfigApplyer struct {
meshManager MeshManager meshManager MeshManager
config *conf.WgMeshConfiguration
routeInstaller route.RouteInstaller
} }
func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) { type routeNode struct {
gateway string
route Route
}
func (r *routeNode) equals(route2 *routeNode) bool {
return r.gateway == route2.gateway && RouteEquals(r.route, route2.route)
}
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
peerToClients map[string][]net.IPNet,
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint()) endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
if err != nil { if err != nil {
@ -35,20 +55,109 @@ func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
allowedips := make([]net.IPNet, 1) allowedips := make([]net.IPNet, 1)
allowedips[0] = *node.GetWgHost() allowedips[0] = *node.GetWgHost()
clients, ok := peerToClients[node.GetWgHost().String()]
if ok {
allowedips = append(allowedips, clients...)
}
for _, route := range node.GetRoutes() { for _, route := range node.GetRoutes() {
_, ipnet, _ := net.ParseCIDR(route) bestRoutes := routes[route.GetDestination().String()]
allowedips = append(allowedips, *ipnet)
if len(bestRoutes) == 1 {
allowedips = append(allowedips, *route.GetDestination())
} else if len(bestRoutes) > 1 {
keyFunc := func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
return lib.HashString(pubKey.String())
}
bucketFunc := func(rn routeNode) int {
return lib.HashString(rn.gateway)
}
// Else there is more than one candidate so consistently hash
pickedRoute := lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc)
if pickedRoute.gateway == pubKey.String() {
allowedips = append(allowedips, *route.GetDestination())
}
}
}
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
existing := slices.IndexFunc(device.Peers, func(p wgtypes.Peer) bool {
pubKey, _ := node.GetPublicKey()
return p.PublicKey.String() == pubKey.String()
})
if existing != -1 {
endpoint = device.Peers[existing].Endpoint
} }
peerConfig := wgtypes.PeerConfig{ peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey, PublicKey: pubKey,
Endpoint: endpoint, Endpoint: endpoint,
AllowedIPs: allowedips, AllowedIPs: allowedips,
PersistentKeepaliveInterval: &keepAlive,
} }
return &peerConfig, nil return &peerConfig, nil
} }
// getRoutes: finds the routes with the least hop distance. If more than one route exists
// consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode {
mesh, _ := meshProvider.GetMesh()
routes := make(map[string][]routeNode)
meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
return ipNet
})
for _, node := range mesh.GetNodes() {
pubKey, _ := node.GetPublicKey()
meshRoutes, _ := meshProvider.GetRoutes(pubKey.String())
for _, route := range meshRoutes {
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
if prefix == nil || route == nil || route.GetDestination() == nil {
return false
}
return prefix.Contains(route.GetDestination().IP)
}) {
continue
}
destination := route.GetDestination().String()
otherRoute, ok := routes[destination]
rn := routeNode{
gateway: pubKey.String(),
route: route,
}
if !ok {
otherRoute = make([]routeNode, 1)
otherRoute[0] = rn
routes[destination] = otherRoute
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
otherRoute[0] = rn
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
routes[destination] = append(otherRoute, rn)
}
}
}
return routes
}
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
snap, err := mesh.GetMesh() snap, err := mesh.GetMesh()
@ -56,18 +165,67 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err return err
} }
nodes := snap.GetNodes() nodes := lib.MapValues(snap.GetNodes())
peerConfigs := make([]wgtypes.PeerConfig, len(nodes)) peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
peers := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.PEER_ROLE
})
var count int = 0 var count int = 0
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
if err != nil {
return err
}
peerToClients := make(map[string][]net.IPNet)
routes := m.getRoutes(mesh)
installedRoutes := make([]lib.Route, 0)
for _, n := range nodes { for _, n := range nodes {
peer, err := convertMeshNode(n) if NodeEquals(n, self) {
continue
}
if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE {
hashFunc := func(mn MeshNode) int {
return lib.HashString(mn.GetWgHost().String())
}
peer := lib.ConsistentHash(peers, n, hashFunc, hashFunc)
clients, ok := peerToClients[peer.GetWgHost().String()]
if !ok {
clients = make([]net.IPNet, 0)
peerToClients[peer.GetWgHost().String()] = clients
}
peerToClients[peer.GetWgHost().String()] = append(clients, *n.GetWgHost())
continue
}
dev, _ := mesh.GetDevice()
peer, err := m.convertMeshNode(n, dev, peerToClients, routes)
if err != nil { if err != nil {
return err return err
} }
for _, route := range peer.AllowedIPs {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
if !ipNet.Contains(route.IP) {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: n.GetWgHost().IP,
Destination: route,
})
}
}
peerConfigs[count] = *peer peerConfigs[count] = *peer
count++ count++
} }
@ -82,6 +240,12 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err return err
} }
err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
if err != nil {
return err
}
return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg) return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
} }
@ -112,7 +276,7 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{ m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
ReplacePeers: true, ReplacePeers: true,
Peers: make([]wgtypes.PeerConfig, 1), Peers: make([]wgtypes.PeerConfig, 0),
}) })
return nil return nil
@ -122,6 +286,9 @@ func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager m.meshManager = manager
} }
func NewWgMeshConfigApplyer() MeshConfigApplyer { func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer {
return &WgMeshConfigApplyer{} return &WgMeshConfigApplyer{
config: config,
routeInstaller: route.NewRouteInstaller(),
}
} }

View File

@ -61,7 +61,7 @@ func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId strin
self, _ := c.manager.GetSelf(meshId) self, _ := c.manager.GetSelf(meshId)
if node.GetHostEndpoint() == self.GetHostEndpoint() { if NodeEquals(self, node) {
return return
} }

View File

@ -14,11 +14,10 @@ import (
) )
type MeshManager interface { type MeshManager interface {
CreateMesh(devName string, port int) (string, error) CreateMesh(port int) (string, error)
AddMesh(params *AddMeshParams) error AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider GetMesh(meshId string) MeshProvider
EnableInterface(meshId string) error
GetPublicKey(meshId string) (*wgtypes.Key, error) GetPublicKey(meshId string) (*wgtypes.Key, error)
AddSelf(params *AddSelfParams) error AddSelf(params *AddSelfParams) error
LeaveMesh(meshId string) error LeaveMesh(meshId string) error
@ -35,6 +34,7 @@ type MeshManager interface {
Close() error Close() error
GetMonitor() MeshMonitor GetMonitor() MeshMonitor
GetNode(string, string) MeshNode GetNode(string, string) MeshNode
GetRouteManager() RouteManager
} }
type MeshManagerImpl struct { type MeshManagerImpl struct {
@ -54,10 +54,15 @@ type MeshManagerImpl struct {
Monitor MeshMonitor Monitor MeshMonitor
} }
// GetRouteManager implements MeshManager.
func (m *MeshManagerImpl) GetRouteManager() RouteManager {
return m.RouteManager
}
// RemoveService implements MeshManager. // RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error { func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes { for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.HostEndpoint, service) err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
if err != nil { if err != nil {
return err return err
@ -70,7 +75,7 @@ func (m *MeshManagerImpl) RemoveService(service string) error {
// SetService implements MeshManager. // SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error { func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes { for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.HostEndpoint, service, value) err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
if err != nil { if err != nil {
return err return err
@ -115,15 +120,25 @@ func (m *MeshManagerImpl) Prune() error {
} }
// CreateMesh: Creates a new mesh, stores it and returns the mesh id // CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) { func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
meshId, err := m.idGenerator.GetId() meshId, err := m.idGenerator.GetId()
var ifName string = ""
if err != nil { if err != nil {
return "", err return "", err
} }
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(port, m.HostParameters.PrivateKey)
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: devName, DevName: ifName,
Port: port, Port: port,
Conf: m.conf, Conf: m.conf,
Client: m.Client, Client: m.Client,
@ -134,32 +149,31 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
return "", fmt.Errorf("error creating mesh: %w", err) return "", fmt.Errorf("error creating mesh: %w", err)
} }
if !m.conf.StubWg {
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName,
Port: port,
})
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
m.Meshes[meshId] = nodeManager m.Meshes[meshId] = nodeManager
return meshId, nil return meshId, nil
} }
type AddMeshParams struct { type AddMeshParams struct {
MeshId string MeshId string
DevName string
WgPort int WgPort int
MeshBytes []byte MeshBytes []byte
} }
// AddMesh: Add the mesh to the list of meshes // AddMesh: Add the mesh to the list of meshes
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string
var err error
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey)
if err != nil {
return err
}
}
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: params.DevName, DevName: ifName,
Port: params.WgPort, Port: params.WgPort,
Conf: m.conf, Conf: m.conf,
Client: m.Client, Client: m.Client,
@ -177,14 +191,6 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
} }
m.Meshes[params.MeshId] = meshProvider m.Meshes[params.MeshId] = meshProvider
if !m.conf.StubWg {
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName,
Port: params.WgPort,
})
}
return nil return nil
} }
@ -199,23 +205,6 @@ func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
return theMesh return theMesh
} }
// EnableInterface: Enables the given WireGuard interface.
func (s *MeshManagerImpl) EnableInterface(meshId string) error {
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
err = s.RouteManager.InstallRoutes()
if err != nil {
return err
}
return nil
}
// GetPublicKey: Gets the public key of the WireGuard mesh // GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) { func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) {
if s.conf.StubWg { if s.conf.StubWg {
@ -255,20 +244,26 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
return fmt.Errorf("addself: mesh %s does not exist", params.MeshId) return fmt.Errorf("addself: mesh %s does not exist", params.MeshId)
} }
pubKey, err := s.GetPublicKey(params.MeshId) if params.WgPort == 0 && !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil { if err != nil {
return err return err
}
params.WgPort = device.ListenPort
} }
nodeIP, err := s.ipAllocator.GetIP(*pubKey, params.MeshId) pubKey := s.HostParameters.PrivateKey.PublicKey()
nodeIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId)
if err != nil { if err != nil {
return err return err
} }
node := s.nodeFactory.Build(&MeshNodeFactoryParams{ node := s.nodeFactory.Build(&MeshNodeFactoryParams{
PublicKey: pubKey, PublicKey: &pubKey,
NodeIP: nodeIP, NodeIP: nodeIP,
WgPort: params.WgPort, WgPort: params.WgPort,
Endpoint: params.Endpoint, Endpoint: params.Endpoint,
@ -300,22 +295,23 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
return fmt.Errorf("mesh %s does not exist", meshId) return fmt.Errorf("mesh %s does not exist", meshId)
} }
err := s.RouteManager.RemoveRoutes(meshId) var err error
if err != nil {
return err
}
if !s.conf.StubWg { if !s.conf.StubWg {
device, e := mesh.GetDevice() device, err := mesh.GetDevice()
if e != nil { if err != nil {
return err return err
} }
err = s.interfaceManipulator.RemoveInterface(device.Name) err = s.interfaceManipulator.RemoveInterface(device.Name)
if err != nil {
return err
}
} }
err = s.RouteManager.RemoveRoutes(meshId)
delete(s.Meshes, meshId) delete(s.Meshes, meshId)
return err return err
} }
@ -327,7 +323,8 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return nil, fmt.Errorf("mesh %s does not exist", meshId) return nil, fmt.Errorf("mesh %s does not exist", meshId)
} }
node, err := meshInstance.GetNode(s.HostParameters.HostEndpoint) logging.Log.WriteInfof(s.HostParameters.GetPublicKey())
node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
if err != nil { if err != nil {
return nil, errors.New("the node doesn't exist in the mesh") return nil, errors.New("the node doesn't exist in the mesh")
@ -337,6 +334,10 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
} }
func (s *MeshManagerImpl) ApplyConfig() error { func (s *MeshManagerImpl) ApplyConfig() error {
if s.conf.StubWg {
return nil
}
err := s.configApplyer.ApplyConfig() err := s.configApplyer.ApplyConfig()
if err != nil { if err != nil {
@ -348,8 +349,8 @@ func (s *MeshManagerImpl) ApplyConfig() error {
func (s *MeshManagerImpl) SetDescription(description string) error { func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) { if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description) err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
if err != nil { if err != nil {
return err return err
@ -363,8 +364,8 @@ func (s *MeshManagerImpl) SetDescription(description string) error {
// SetAlias implements MeshManager. // SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error { func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) { if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.HostEndpoint, alias) err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
if err != nil { if err != nil {
return err return err
@ -377,8 +378,8 @@ func (s *MeshManagerImpl) SetAlias(alias string) error {
// UpdateTimeStamp updates the timestamp of this node in all meshes // UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error { func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.HostEndpoint) { if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint) err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey())
if err != nil { if err != nil {
return err return err
@ -435,17 +436,11 @@ type NewMeshManagerParams struct {
// Creates a new instance of a mesh manager with the given parameters // Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) MeshManager { func NewMeshManager(params *NewMeshManagerParams) MeshManager {
hostParams := HostParameters{} privateKey, _ := wgtypes.GeneratePrivateKey()
hostParams := HostParameters{
switch params.Conf.Endpoint { PrivateKey: &privateKey,
case "":
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", lib.GetOutboundIP().String(), params.Conf.GrpcPort)
default:
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", params.Conf.Endpoint, params.Conf.GrpcPort)
} }
logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint)
m := &MeshManagerImpl{ m := &MeshManagerImpl{
Meshes: make(map[string]MeshProvider), Meshes: make(map[string]MeshProvider),
HostParameters: &hostParams, HostParameters: &hostParams,

View File

@ -64,7 +64,6 @@ func TestAddMeshAddsAMesh(t *testing.T) {
manager.AddMesh(&AddMeshParams{ manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -83,7 +82,6 @@ func TestAddMeshMeshAlreadyExistsReplacesIt(t *testing.T) {
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
err := manager.AddMesh(&AddMeshParams{ err := manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -106,7 +104,6 @@ func TestAddSelfAddsSelfToTheMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{ err := manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -175,7 +172,6 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{ err := manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -201,8 +197,8 @@ func TestSetDescription(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
description := "wooooo" description := "wooooo"
meshId1, _ := manager.CreateMesh("wg0", 5000) meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh("wg0", 5001) meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId1, MeshId: meshId1,
@ -225,8 +221,8 @@ func TestSetDescription(t *testing.T) {
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) { func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
meshId1, _ := manager.CreateMesh("wg0", 5000) meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh("wg0", 5001) meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId1, MeshId: meshId1,

View File

@ -1,25 +1,18 @@
package mesh package mesh
import ( import (
"fmt"
"net"
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.org/x/sys/unix"
) )
type RouteManager interface { type RouteManager interface {
UpdateRoutes() error UpdateRoutes() error
InstallRoutes() error
RemoveRoutes(meshId string) error RemoveRoutes(meshId string) error
} }
type RouteManagerImpl struct { type RouteManagerImpl struct {
meshManager MeshManager meshManager MeshManager
routeInstaller route.RouteInstaller
} }
func (r *RouteManagerImpl) UpdateRoutes() error { func (r *RouteManagerImpl) UpdateRoutes() error {
@ -27,6 +20,24 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
ulaBuilder := new(ip.ULABuilder) ulaBuilder := new(ip.ULABuilder)
for _, mesh1 := range meshes { for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
if err != nil {
return err
}
pubKey, err := self.GetPublicKey()
if err != nil {
return err
}
routes, err := mesh1.GetRoutes(pubKey.String())
if err != nil {
return err
}
for _, mesh2 := range meshes { for _, mesh2 := range meshes {
if mesh1 == mesh2 { if mesh1 == mesh2 {
continue continue
@ -39,13 +50,11 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return err return err
} }
self, err := r.meshManager.GetSelf(mesh1.GetMeshId()) err = mesh2.AddRoutes(NodeID(self), append(lib.MapValues(routes), &RouteStub{
Destination: ipNet,
if err != nil { HopCount: 0,
return err Path: make([]string, 0),
} })...)
err = mesh1.AddRoutes(self.GetHostEndpoint(), ipNet.String())
if err != nil { if err != nil {
return err return err
@ -74,111 +83,11 @@ func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
return err return err
} }
mesh1.RemoveRoutes(self.GetHostEndpoint(), ipNet.String()) mesh1.RemoveRoutes(NodeID(self), ipNet.String())
} }
return nil return nil
} }
// AddRoute adds a route to the given interface
func (m *RouteManagerImpl) addRoute(ifName string, meshPrefix string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return fmt.Errorf("failed to create config: %w", err)
}
defer rtnl.Close()
// Delete any routes that may be vacant
err = rtnl.DeleteRoutes(ifName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes {
if route.Destination.String() == meshPrefix {
continue
}
err = rtnl.AddRoute(ifName, route)
if err != nil {
return err
}
}
return nil
}
func (m *RouteManagerImpl) installRoute(ifName string, meshid string, node MeshNode) error {
routeMapFunc := func(route string) lib.Route {
_, cidr, _ := net.ParseCIDR(route)
r := lib.Route{
Destination: *cidr,
Gateway: node.GetWgHost().IP,
}
return r
}
ipBuilder := &ip.ULABuilder{}
ipNet, err := ipBuilder.GetIPNet(meshid)
if err != nil {
return err
}
routes := lib.Map(append(node.GetRoutes(), ipNet.String()), routeMapFunc)
return m.addRoute(ifName, ipNet.String(), routes...)
}
func (m *RouteManagerImpl) installRoutes(meshProvider MeshProvider) error {
mesh, err := meshProvider.GetMesh()
if err != nil {
return err
}
dev, err := meshProvider.GetDevice()
if err != nil {
return err
}
self, err := m.meshManager.GetSelf(meshProvider.GetMeshId())
if err != nil {
return err
}
for _, node := range mesh.GetNodes() {
if self.GetHostEndpoint() == node.GetHostEndpoint() {
continue
}
err = m.installRoute(dev.Name, meshProvider.GetMeshId(), node)
if err != nil {
return err
}
}
return nil
}
// InstallRoutes installs all routes to the RIB
func (r *RouteManagerImpl) InstallRoutes() error {
for _, mesh := range r.meshManager.GetMeshes() {
err := r.installRoutes(mesh)
if err != nil {
return err
}
}
return nil
}
func NewRouteManager(m MeshManager) RouteManager { func NewRouteManager(m MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()} return &RouteManagerImpl{meshManager: m}
} }

View File

@ -16,11 +16,16 @@ type MeshNodeStub struct {
wgEndpoint string wgEndpoint string
wgHost *net.IPNet wgHost *net.IPNet
timeStamp int64 timeStamp int64
routes []string routes []Route
identifier string identifier string
description string description string
} }
// GetType implements MeshNode.
func (*MeshNodeStub) GetType() conf.NodeType {
return conf.PEER_ROLE
}
// GetServices implements MeshNode. // GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string { func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string) return make(map[string]string)
@ -51,7 +56,7 @@ func (m *MeshNodeStub) GetTimeStamp() int64 {
return m.timeStamp return m.timeStamp
} }
func (m *MeshNodeStub) GetRoutes() []string { func (m *MeshNodeStub) GetRoutes() []Route {
return m.routes return m.routes
} }
@ -76,29 +81,33 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub snapshot *MeshSnapshotStub
} }
func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) {
return nil, nil
}
// GetNodeIds implements MeshProvider. // GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetNodeIds() []string { func (*MeshProviderStub) GetPeers() []string {
panic("unimplemented") return make([]string, 0)
} }
// GetNode implements MeshProvider. // GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) { func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
panic("unimplemented") return nil, nil
} }
// NodeExists implements MeshProvider. // NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool { func (*MeshProviderStub) NodeExists(string) bool {
panic("unimplemented") return false
} }
// AddService implements MeshProvider. // AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error { func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
panic("unimplemented") return nil
} }
// RemoveService implements MeshProvider. // RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error { func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
panic("unimplemented") return nil
} }
// SetAlias implements MeshProvider. // SetAlias implements MeshProvider.
@ -108,7 +117,7 @@ func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
// RemoveRoutes implements MeshProvider. // RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error { func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
panic("unimplemented") return nil
} }
// Prune implements MeshProvider. // Prune implements MeshProvider.
@ -154,7 +163,7 @@ func (s *MeshProviderStub) HasChanges() bool {
return false return false
} }
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...string) error { func (s *MeshProviderStub) AddRoutes(nodeId string, route ...Route) error {
return nil return nil
} }
@ -188,7 +197,7 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort), wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort),
wgHost: wgHost, wgHost: wgHost,
timeStamp: time.Now().Unix(), timeStamp: time.Now().Unix(),
routes: make([]string, 0), routes: make([]Route, 0),
identifier: "abc", identifier: "abc",
description: "A Mesh Node Stub", description: "A Mesh Node Stub",
} }
@ -211,6 +220,11 @@ type MeshManagerStub struct {
meshes map[string]MeshProvider meshes map[string]MeshProvider
} }
// GetRouteManager implements MeshManager.
func (*MeshManagerStub) GetRouteManager() RouteManager {
panic("unimplemented")
}
// GetNode implements MeshManager. // GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode { func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented") panic("unimplemented")
@ -250,7 +264,7 @@ func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)} return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
} }
func (m *MeshManagerStub) CreateMesh(devName string, port int) (string, error) { func (m *MeshManagerStub) CreateMesh(port int) (string, error) {
return "tim123", nil return "tim123", nil
} }
@ -273,10 +287,6 @@ func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider {
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}} snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}}
} }
func (m *MeshManagerStub) EnableInterface(meshId string) error {
return nil
}
func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) { func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) {
key, _ := wgtypes.GenerateKey() key, _ := wgtypes.GenerateKey()
return &key, nil return &key, nil

View File

@ -4,13 +4,39 @@ package mesh
import ( import (
"net" "net"
"slices"
"github.com/tim-beatham/wgmesh/pkg/conf" "github.com/tim-beatham/wgmesh/pkg/conf"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
type Route interface {
// GetDestination: returns the destination of the route
GetDestination() *net.IPNet
// GetHopCount: get the total hopcount of the prefix
GetHopCount() int
// GetPath: get a list of AS paths to get to the destination
GetPath() []string
}
type RouteStub struct {
Destination *net.IPNet
HopCount int
Path []string
}
func (r *RouteStub) GetDestination() *net.IPNet {
return r.Destination
}
func (r *RouteStub) GetHopCount() int {
return r.HopCount
}
func (r *RouteStub) GetPath() []string {
return r.Path
}
// MeshNode represents an implementation of a node in a mesh // MeshNode represents an implementation of a node in a mesh
type MeshNode interface { type MeshNode interface {
// GetHostEndpoint: gets the gRPC endpoint of the node // GetHostEndpoint: gets the gRPC endpoint of the node
@ -24,7 +50,7 @@ type MeshNode interface {
// GetTimestamp: get the UNIX time stamp of the ndoe // GetTimestamp: get the UNIX time stamp of the ndoe
GetTimeStamp() int64 GetTimeStamp() int64
// GetRoutes: returns the routes that the nodes provides // GetRoutes: returns the routes that the nodes provides
GetRoutes() []string GetRoutes() []Route
// GetIdentifier: returns the identifier of the node // GetIdentifier: returns the identifier of the node
GetIdentifier() string GetIdentifier() string
// GetDescription: returns the description for this node // GetDescription: returns the description for this node
@ -34,46 +60,25 @@ type MeshNode interface {
GetAlias() string GetAlias() string
// GetServices: returns a list of services offered by the node // GetServices: returns a list of services offered by the node
GetServices() map[string]string GetServices() map[string]string
GetType() conf.NodeType
} }
// NodeEquals: determines if two mesh nodes are equivalent to one another // NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool { func NodeEquals(node1, node2 MeshNode) bool {
if node1.GetHostEndpoint() != node2.GetHostEndpoint() { key1, _ := node1.GetPublicKey()
return false key2, _ := node2.GetPublicKey()
}
node1Pub, _ := node1.GetPublicKey() return key1.String() == key2.String()
node2Pub, _ := node2.GetPublicKey() }
if node1Pub != node2Pub { func RouteEquals(route1, route2 Route) bool {
return false return route1.GetDestination().String() == route2.GetDestination().String() &&
} route1.GetHopCount() == route2.GetHopCount()
}
if node1.GetWgEndpoint() != node2.GetWgEndpoint() { func NodeID(node MeshNode) string {
return false key, _ := node.GetPublicKey()
} return key.String()
if node1.GetWgHost() != node2.GetWgHost() {
return false
}
if !slices.Equal(node1.GetRoutes(), node2.GetRoutes()) {
return false
}
if node1.GetIdentifier() != node2.GetIdentifier() {
return false
}
if node1.GetDescription() != node2.GetDescription() {
return false
}
if node1.GetAlias() != node2.GetAlias() {
return false
}
return true
} }
type MeshSnapshot interface { type MeshSnapshot interface {
@ -109,7 +114,7 @@ type MeshProvider interface {
// UpdateTimeStamp: update the timestamp of the given node // UpdateTimeStamp: update the timestamp of the given node
UpdateTimeStamp(nodeId string) error UpdateTimeStamp(nodeId string) error
// AddRoutes: adds routes to the given node // AddRoutes: adds routes to the given node
AddRoutes(nodeId string, route ...string) error AddRoutes(nodeId string, route ...Route) error
// DeleteRoutes: deletes the routes from the node // DeleteRoutes: deletes the routes from the node
RemoveRoutes(nodeId string, route ...string) error RemoveRoutes(nodeId string, route ...string) error
// GetSyncer: returns the automerge syncer for sync // GetSyncer: returns the automerge syncer for sync
@ -129,12 +134,20 @@ type MeshProvider interface {
// Prune: prunes all nodes that have not updated their timestamp in // Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds // pruneAmount seconds
Prune(pruneAmount int) error Prune(pruneAmount int) error
GetNodeIds() []string // GetPeers: get a list of contactable peers
GetPeers() []string
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
GetRoutes(targetNode string) (map[string]Route, error)
} }
// HostParameters contains the IDs of a node // HostParameters contains the IDs of a node
type HostParameters struct { type HostParameters struct {
HostEndpoint string PrivateKey *wgtypes.Key
}
// GetPublicKey: gets the public key of the node
func (h *HostParameters) GetPublicKey() string {
return h.PrivateKey.PublicKey().String()
} }
// MeshProviderFactoryParams parameters required to build a mesh provider // MeshProviderFactoryParams parameters required to build a mesh provider

View File

@ -3,8 +3,10 @@ package query
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strings"
"github.com/jmespath/go-jmespath" "github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
) )
@ -23,16 +25,23 @@ type QueryError struct {
msg string msg string
} }
type QueryRoute struct {
Destination string `json:"destination"`
HopCount int `json:"hopCount"`
Path string `json:"path"`
}
type QueryNode struct { type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"` HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"` WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"` WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestmap"` Timestamp int64 `json:"timestamp"`
Description string `json:"description"` Description string `json:"description"`
Routes []string `json:"routes"` Routes []QueryRoute `json:"routes"`
Alias string `json:"alias"` Alias string `json:"alias"`
Services map[string]string `json:"services"` Services map[string]string `json:"services"`
Type conf.NodeType `json:"type"`
} }
func (m *QueryError) Error() string { func (m *QueryError) Error() string {
@ -76,10 +85,17 @@ func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.WgHost = node.GetWgHost().String() queryNode.WgHost = node.GetWgHost().String()
queryNode.Timestamp = node.GetTimeStamp() queryNode.Timestamp = node.GetTimeStamp()
queryNode.Routes = node.GetRoutes() queryNode.Routes = lib.Map(node.GetRoutes(), func(r mesh.Route) QueryRoute {
return QueryRoute{
Destination: r.GetDestination().String(),
HopCount: r.GetHopCount(),
Path: strings.Join(r.GetPath(), ","),
}
})
queryNode.Description = node.GetDescription() queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias() queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices() queryNode.Services = node.GetServices()
queryNode.Type = node.GetType()
return queryNode return queryNode
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query" "github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/wgmesh/pkg/rpc"
@ -20,7 +21,7 @@ type IpcHandler struct {
} }
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.GetMeshManager().CreateMesh(args.IfName, args.WgPort) meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort)
if err != nil { if err != nil {
return err return err
@ -72,7 +73,9 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err return err
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second) configuration := n.Server.GetConfiguration()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(configuration.Timeout))
defer cancel() defer cancel()
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId}) meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
@ -83,7 +86,6 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{ err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId, MeshId: args.MeshId,
DevName: args.IfName,
WgPort: args.Port, WgPort: args.Port,
MeshBytes: meshReply.Mesh, MeshBytes: meshReply.Mesh,
}) })
@ -118,19 +120,19 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
} }
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error { func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.GetMeshManager().GetMesh(meshId) theMesh := n.Server.GetMeshManager().GetMesh(meshId)
if mesh == nil { if theMesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId) return fmt.Errorf("mesh %s does not exist", meshId)
} }
meshSnapshot, err := mesh.GetMesh() meshSnapshot, err := theMesh.GetMesh()
if err != nil { if err != nil {
return err return err
} }
if mesh == nil { if theMesh == nil {
return errors.New("mesh does not exist") return errors.New("mesh does not exist")
} }
@ -150,10 +152,15 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
PublicKey: pubKey.String(), PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(), WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(), Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(), Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) ctrlserver.MeshRoute {
Description: node.GetDescription(), return ctrlserver.MeshRoute{
Alias: node.GetAlias(), Destination: r.GetDestination().String(),
Services: node.GetServices(), Path: r.GetPath(),
}
}),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
} }
nodes[i] = node nodes[i] = node
@ -164,18 +171,6 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
return nil return nil
} }
func (n *IpcHandler) EnableInterface(meshId string, reply *string) error {
err := n.Server.GetMeshManager().EnableInterface(meshId)
if err != nil {
*reply = err.Error()
return err
}
*reply = "up"
return nil
}
func (n *IpcHandler) GetDOT(meshId string, reply *string) error { func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager()) g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())

View File

@ -28,7 +28,3 @@ func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.
return &reply, nil return &reply, nil
} }
func (m *WgRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
return &rpc.JoinMeshReply{Success: true}, nil
}

View File

@ -1,22 +1,32 @@
package route package route
import ( import (
"net" "github.com/tim-beatham/wgmesh/pkg/lib"
"os/exec" "golang.org/x/sys/unix"
logging "github.com/tim-beatham/wgmesh/pkg/log"
) )
type RouteInstaller interface { type RouteInstaller interface {
InstallRoutes(devName string, routes ...*net.IPNet) error InstallRoutes(devName string, routes ...lib.Route) error
} }
type RouteInstallerImpl struct{} type RouteInstallerImpl struct{}
// InstallRoutes: installs a route into the routing table // InstallRoutes: installs a route into the routing table
func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet) error { func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
}
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes { for _, route := range routes {
err := r.installRoute(devName, route) err := rtnl.AddRoute(devName, route)
if err != nil { if err != nil {
return err return err
@ -26,22 +36,6 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet)
return nil return nil
} }
// installRoute: installs a route into the linux table
func (r *RouteInstallerImpl) installRoute(devName string, route *net.IPNet) error {
// TODO: Find a library that automates this
cmd := exec.Command("/usr/bin/ip", "-6", "route", "add", route.String(), "dev", devName)
logging.Log.WriteInfof("%s %s", route.String(), devName)
if msg, err := cmd.CombinedOutput(); err != nil {
logging.Log.WriteErrorf(err.Error())
logging.Log.WriteErrorf(string(msg))
return err
}
return nil
}
func NewRouteInstaller() RouteInstaller { func NewRouteInstaller() RouteInstaller {
return &RouteInstallerImpl{} return &RouteInstallerImpl{}
} }

View File

@ -20,77 +20,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
) )
type MeshNode struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
PublicKey string `protobuf:"bytes,1,opt,name=publicKey,proto3" json:"publicKey,omitempty"`
WgEndpoint string `protobuf:"bytes,2,opt,name=wgEndpoint,proto3" json:"wgEndpoint,omitempty"`
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
WgHost string `protobuf:"bytes,4,opt,name=wgHost,proto3" json:"wgHost,omitempty"`
}
func (x *MeshNode) Reset() {
*x = MeshNode{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *MeshNode) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*MeshNode) ProtoMessage() {}
func (x *MeshNode) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use MeshNode.ProtoReflect.Descriptor instead.
func (*MeshNode) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *MeshNode) GetPublicKey() string {
if x != nil {
return x.PublicKey
}
return ""
}
func (x *MeshNode) GetWgEndpoint() string {
if x != nil {
return x.WgEndpoint
}
return ""
}
func (x *MeshNode) GetEndpoint() string {
if x != nil {
return x.Endpoint
}
return ""
}
func (x *MeshNode) GetWgHost() string {
if x != nil {
return x.WgHost
}
return ""
}
type GetMeshRequest struct { type GetMeshRequest struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
@ -102,7 +31,7 @@ type GetMeshRequest struct {
func (x *GetMeshRequest) Reset() { func (x *GetMeshRequest) Reset() {
*x = GetMeshRequest{} *x = GetMeshRequest{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@ -115,7 +44,7 @@ func (x *GetMeshRequest) String() string {
func (*GetMeshRequest) ProtoMessage() {} func (*GetMeshRequest) ProtoMessage() {}
func (x *GetMeshRequest) ProtoReflect() protoreflect.Message { func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@ -128,7 +57,7 @@ func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead. // Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead.
func (*GetMeshRequest) Descriptor() ([]byte, []int) { func (*GetMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1} return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
} }
func (x *GetMeshRequest) GetMeshId() string { func (x *GetMeshRequest) GetMeshId() string {
@ -149,7 +78,7 @@ type GetMeshReply struct {
func (x *GetMeshReply) Reset() { func (x *GetMeshReply) Reset() {
*x = GetMeshReply{} *x = GetMeshReply{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@ -162,7 +91,7 @@ func (x *GetMeshReply) String() string {
func (*GetMeshReply) ProtoMessage() {} func (*GetMeshReply) ProtoMessage() {}
func (x *GetMeshReply) ProtoReflect() protoreflect.Message { func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@ -175,7 +104,7 @@ func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead. // Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead.
func (*GetMeshReply) Descriptor() ([]byte, []int) { func (*GetMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{2} return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1}
} }
func (x *GetMeshReply) GetMesh() []byte { func (x *GetMeshReply) GetMesh() []byte {
@ -185,145 +114,24 @@ func (x *GetMeshReply) GetMesh() []byte {
return nil return nil
} }
type JoinMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Changes []byte `protobuf:"bytes,1,opt,name=changes,proto3" json:"changes,omitempty"`
MeshId string `protobuf:"bytes,2,opt,name=meshId,proto3" json:"meshId,omitempty"`
}
func (x *JoinMeshRequest) Reset() {
*x = JoinMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshRequest) ProtoMessage() {}
func (x *JoinMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshRequest.ProtoReflect.Descriptor instead.
func (*JoinMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{3}
}
func (x *JoinMeshRequest) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
func (x *JoinMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
type JoinMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
}
func (x *JoinMeshReply) Reset() {
*x = JoinMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshReply) ProtoMessage() {}
func (x *JoinMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshReply.ProtoReflect.Descriptor instead.
func (*JoinMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{4}
}
func (x *JoinMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
var File_pkg_grpc_ctrlserver_ctrlserver_proto protoreflect.FileDescriptor var File_pkg_grpc_ctrlserver_ctrlserver_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc = []byte{ var file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc = []byte{
0x0a, 0x24, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x0a, 0x24, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x22, 0x7c, 0x0a, 0x08, 0x4d, 0x65, 0x73, 0x68, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65,
0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01,
0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x77, 0x67, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65,
0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65,
0x77, 0x67, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x32, 0x4f,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, 0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70,
0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, 0x22, 0x28, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42,
0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x6f, 0x33,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x22, 0x43, 0x0a, 0x0f,
0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12,
0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73,
0x68, 0x49, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49,
0x64, 0x22, 0x29, 0x0a, 0x0d, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20,
0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x32, 0x91, 0x01, 0x0a,
0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12,
0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e,
0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x40,
0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x19, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00,
0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
} }
var ( var (
@ -338,21 +146,16 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP() []byte {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescData return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescData
} }
var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_ctrlserver_proto_goTypes = []interface{}{ var file_pkg_grpc_ctrlserver_ctrlserver_proto_goTypes = []interface{}{
(*MeshNode)(nil), // 0: rpctypes.MeshNode (*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest
(*GetMeshRequest)(nil), // 1: rpctypes.GetMeshRequest (*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply
(*GetMeshReply)(nil), // 2: rpctypes.GetMeshReply
(*JoinMeshRequest)(nil), // 3: rpctypes.JoinMeshRequest
(*JoinMeshReply)(nil), // 4: rpctypes.JoinMeshReply
} }
var file_pkg_grpc_ctrlserver_ctrlserver_proto_depIdxs = []int32{ var file_pkg_grpc_ctrlserver_ctrlserver_proto_depIdxs = []int32{
1, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest 0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
3, // 1: rpctypes.MeshCtrlServer.JoinMesh:input_type -> rpctypes.JoinMeshRequest 1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
2, // 2: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply 1, // [1:2] is the sub-list for method output_type
4, // 3: rpctypes.MeshCtrlServer.JoinMesh:output_type -> rpctypes.JoinMeshReply 0, // [0:1] is the sub-list for method input_type
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name 0, // [0:0] is the sub-list for field type_name
@ -365,18 +168,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
} }
if !protoimpl.UnsafeEnabled { if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*MeshNode); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshRequest); i { switch v := v.(*GetMeshRequest); i {
case 0: case 0:
return &v.state return &v.state
@ -388,7 +179,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil return nil
} }
} }
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshReply); i { switch v := v.(*GetMeshReply); i {
case 0: case 0:
return &v.state return &v.state
@ -400,30 +191,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil return nil
} }
} }
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
@ -431,7 +198,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc, RawDescriptor: file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc,
NumEnums: 0, NumEnums: 0,
NumMessages: 5, NumMessages: 2,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@ -23,7 +23,6 @@ const _ = grpc.SupportPackageIsVersion7
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type MeshCtrlServerClient interface { type MeshCtrlServerClient interface {
GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error) GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error)
JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error)
} }
type meshCtrlServerClient struct { type meshCtrlServerClient struct {
@ -43,21 +42,11 @@ func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest,
return out, nil return out, nil
} }
func (c *meshCtrlServerClient) JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error) {
out := new(JoinMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/JoinMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// MeshCtrlServerServer is the server API for MeshCtrlServer service. // MeshCtrlServerServer is the server API for MeshCtrlServer service.
// All implementations must embed UnimplementedMeshCtrlServerServer // All implementations must embed UnimplementedMeshCtrlServerServer
// for forward compatibility // for forward compatibility
type MeshCtrlServerServer interface { type MeshCtrlServerServer interface {
GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error)
JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error)
mustEmbedUnimplementedMeshCtrlServerServer() mustEmbedUnimplementedMeshCtrlServerServer()
} }
@ -68,9 +57,6 @@ type UnimplementedMeshCtrlServerServer struct {
func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) { func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented")
} }
func (UnimplementedMeshCtrlServerServer) JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method JoinMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {} func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {}
// UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service. // UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service.
@ -102,24 +88,6 @@ func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec f
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _MeshCtrlServer_JoinMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JoinMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.MeshCtrlServer/JoinMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, req.(*JoinMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service. // MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@ -131,10 +99,6 @@ var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetMesh", MethodName: "GetMesh",
Handler: _MeshCtrlServer_GetMesh_Handler, Handler: _MeshCtrlServer_GetMesh_Handler,
}, },
{
MethodName: "JoinMesh",
Handler: _MeshCtrlServer_JoinMesh_Handler,
},
}, },
Streams: []grpc.StreamDesc{}, Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver/ctrlserver.proto", Metadata: "pkg/grpc/ctrlserver/ctrlserver.proto",

View File

@ -44,15 +44,20 @@ func (s *SyncerImpl) Sync(meshId string) error {
} }
} }
nodeNames := s.manager.GetMesh(meshId).GetNodeIds() nodeNames := s.manager.GetMesh(meshId).GetPeers()
self, err := s.manager.GetSelf(meshId) self, err := s.manager.GetSelf(meshId)
if err != nil { if err != nil {
return err return err
} }
neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint()) selfPublickey, err := self.GetPublicKey()
if err != nil {
return err
}
neighbours := s.cluster.GetNeighbours(nodeNames, selfPublickey.String())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
for _, node := range randomSubset { for _, node := range randomSubset {
@ -63,7 +68,7 @@ func (s *SyncerImpl) Sync(meshId string) error {
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster") logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint()) interCluster := s.cluster.GetInterCluster(nodeNames, selfPublickey.String())
randomSubset = append(randomSubset, interCluster) randomSubset = append(randomSubset, interCluster)
} }
@ -74,7 +79,14 @@ func (s *SyncerImpl) Sync(meshId string) error {
go func(i int) error { go func(i int) error {
defer waitGroup.Done() defer waitGroup.Done()
err := s.requester.SyncMesh(meshId, randomSubset[i])
correspondingPeer := s.manager.GetNode(meshId, randomSubset[i])
if correspondingPeer == nil {
logging.Log.WriteErrorf("node %s does not exist", randomSubset[i])
}
err := s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint())
return err return err
}(index) }(index)
} }

View File

@ -50,7 +50,6 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{ err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: ifName,
WgPort: port, WgPort: port,
MeshBytes: reply.Mesh, MeshBytes: reply.Mesh,
}) })

View File

@ -5,20 +5,6 @@ import (
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
) )
// SyncScheduler: Loops through all nodes in the mesh and runs a schedule to
// sync each event
type SyncScheduler interface {
Run() error
Stop() error
}
// SyncSchedulerImpl scheduler for sync scheduling
type SyncSchedulerImpl struct {
quit chan struct{}
server *ctrlserver.MeshCtrlServer
syncer Syncer
}
// Run implements SyncScheduler. // Run implements SyncScheduler.
func syncFunction(syncer Syncer) lib.TimerFunc { func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error { return func() error {

View File

@ -64,11 +64,11 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
syncer = mesh.GetSyncer() syncer = mesh.GetSyncer()
} else if meshId != in.MeshId { } else if meshId != in.MeshId {
return errors.New("Differing MeshIDs") return errors.New("differing meshids")
} }
if syncer == nil { if syncer == nil {
return errors.New("Syncer should not be nil") return errors.New("syncer should not be nil")
} }
msg, moreMessages := syncer.GenerateMessage() msg, moreMessages := syncer.GenerateMessage()

View File

@ -1,4 +1,4 @@
package timestamp package timer
import ( import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
@ -12,3 +12,11 @@ func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime) return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
} }
func NewRouteScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
return ctrlServer.MeshManager.GetRouteManager().UpdateRoutes()
}
return *lib.NewTimer(timerFunc, 10)
}

View File

@ -2,8 +2,8 @@ package wg
type WgInterfaceManipulatorStub struct{} type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(params *CreateInterfaceParams) error { func (i *WgInterfaceManipulatorStub) CreateInterface(port int) (string, error) {
return nil return "", nil
} }
func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error { func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error {

View File

@ -1,5 +1,7 @@
package wg package wg
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
type WgError struct { type WgError struct {
msg string msg string
} }
@ -8,14 +10,9 @@ func (m *WgError) Error() string {
return m.msg return m.msg
} }
type CreateInterfaceParams struct {
IfName string
Port int
}
type WgInterfaceManipulator interface { type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface // CreateInterface creates a WireGuard interface
CreateInterface(params *CreateInterfaceParams) error CreateInterface(port int, privateKey *wgtypes.Key) (string, error)
// AddAddress adds an address to the given interface name // AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface // RemoveInterface removes the specified interface

View File

@ -1,6 +1,9 @@
package wg package wg
import ( import (
"crypto"
"crypto/rand"
"encoding/base64"
"fmt" "fmt"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
@ -13,40 +16,48 @@ type WgInterfaceManipulatorImpl struct {
client *wgctrl.Client client *wgctrl.Client
} }
const hashLength = 6
// CreateInterface creates a WireGuard interface // CreateInterface creates a WireGuard interface
func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error { func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes.Key) (string, error) {
rtnl, err := lib.NewRtNetlinkConfig() rtnl, err := lib.NewRtNetlinkConfig()
if err != nil { if err != nil {
return fmt.Errorf("failed to access link: %w", err) return "", fmt.Errorf("failed to access link: %w", err)
} }
defer rtnl.Close() defer rtnl.Close()
err = rtnl.CreateLink(params.IfName) randomBuf := make([]byte, 32)
_, err = rand.Read(randomBuf)
if err != nil { if err != nil {
return fmt.Errorf("failed to create link: %w", err) return "", err
} }
privateKey, err := wgtypes.GeneratePrivateKey() md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
err = rtnl.CreateLink(md5Str)
if err != nil { if err != nil {
return fmt.Errorf("failed to create private key: %w", err) return "", fmt.Errorf("failed to create link: %w", err)
} }
var cfg wgtypes.Config = wgtypes.Config{ var cfg wgtypes.Config = wgtypes.Config{
PrivateKey: &privateKey, PrivateKey: privKey,
ListenPort: &params.Port, ListenPort: &port,
} }
err = m.client.ConfigureDevice(params.IfName, cfg) err = m.client.ConfigureDevice(md5Str, cfg)
if err != nil { if err != nil {
return fmt.Errorf("failed to configure dev: %w", err) m.RemoveInterface(md5Str)
return "", fmt.Errorf("failed to configure dev: %w", err)
} }
logging.Log.WriteInfof("ip link set up dev %s type wireguard", params.IfName) logging.Log.WriteInfof("ip link set up dev %s type wireguard", md5Str)
return nil return md5Str, nil
} }
// Add an address to the given interface // Add an address to the given interface

View File

@ -0,0 +1,92 @@
// Package to convert an IPV6 addres into 8 words
package what8words
import (
"bufio"
"bytes"
"fmt"
"net"
"os"
"strings"
)
type What8Words struct {
words []string
}
// Convert implements What8Words.
func (w *What8Words) Convert(ipStr string) (string, error) {
ip, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
return "", err
}
ip16 := ip.To16()
if ip16 == nil {
return "", fmt.Errorf("cannot convert ip to 16 representation")
}
representation := make([]string, 7)
for i := 2; i <= net.IPv6len-2; i += 2 {
word1 := w.words[ip16[i]]
word2 := w.words[ip16[i+1]]
representation[i/2-1] = fmt.Sprintf("%s-%s", word1, word2)
}
prefixSize, _ := ipNet.Mask.Size()
return strings.Join(representation[:prefixSize/16-1], "."), nil
}
// Convert implements What8Words.
func (w *What8Words) ConvertIdentifier(ipStr string) (string, error) {
ip, err := w.Convert(ipStr)
if err != nil {
return "", err
}
constituents := strings.Split(ip, ".")
return strings.Join(constituents[3:], "."), nil
}
func NewWhat8Words(pathToWords string) (*What8Words, error) {
words, err := ReadWords(pathToWords)
if err != nil {
return nil, err
}
return &What8Words{words: words}, nil
}
// ReadWords reads the what 8 words txt file
func ReadWords(wordFile string) ([]string, error) {
f, err := os.ReadFile(wordFile)
if err != nil {
return nil, err
}
words := make([]string, 257)
reader := bufio.NewScanner(bytes.NewReader(f))
counter := 0
for reader.Scan() && counter <= len(words) {
text := reader.Text()
words[counter] = text
counter++
if reader.Err() != nil {
return nil, reader.Err()
}
}
return words, nil
}