Compare commits

..

35 Commits

Author SHA1 Message Date
a2517a1e72 34-fix-routing
- Added mesh-to-mesh routing of hop count > 1
- If there is a tie-breaker with respect to the hop-count use consistent
hashing to determine the route to take based on the public key.
2023-11-27 15:56:30 +00:00
aef8b59f22 32-fix-routing
Flooding routes into other meshes a bit like BGP.
2023-11-25 03:15:58 +00:00
4030d17b41 Fixed routing issue 2023-11-24 17:49:06 +00:00
73db65660b Merge pull request #33 from tim-beatham/32-incorporate-dns
32-incorporate-dns
2023-11-24 15:05:40 +00:00
d1a74a7b95 32-incorporate-dns
Incorporated a DNS server. A DNS server can be run to resolve host
names.
2023-11-24 15:04:07 +00:00
f28ed8260d Merge pull request #30 from tim-beatham/29-only-ping-clients-who-have-updated-their-config
29-only-ping-clients-who-have-updated-their-config
2023-11-24 12:39:14 +00:00
2c406718df 29-only-ping-clients-who-have-updated-their-config
Only consider clients who have updated their config when synchronising
with peers. Consider a dead time where we don't have a handshake and
a prune time when we remove them from the WireGuard configuration.
2023-11-24 12:37:54 +00:00
11b003b549 Merge pull request #28 from tim-beatham/27-remove-client-grpc-endpoint
27-remove-client-grpc-endpoint
2023-11-24 12:08:42 +00:00
7be11dbaa3 27-remove-client-grpc-endpoint
Removed a client's grpc endpoint value. Client's aren't publicly
available so there is no need for a client's gRPC endpoint.
Also changed a node ID's to their public key. A node id's public
address is an issue for mobility of clients as their endpoint
is subject to change
2023-11-24 12:07:03 +00:00
e7ac8c5542 Only updating WireGuard config if node exists 2023-11-22 13:08:02 +00:00
09c64c4628 Fixed container file 2023-11-22 12:45:01 +00:00
2c4f18f52b Merge pull request #26 from tim-beatham/25-modify-code-to-use-public-api
25-modify-code-to-use-public-api
2023-11-22 10:42:48 +00:00
4c54022f63 25-modify-code-to-use-public-api
Modify the code to use a public IP address by default if none is
specified
2023-11-22 10:41:54 +00:00
bf0724f6e5 Merge pull request #24 from tim-beatham/24-keepalive-holepunch
24 keepalive holepunch
2023-11-21 21:28:16 +00:00
624bd6e921 24-keepalive
Persistent keep alive working
2023-11-21 21:26:31 +00:00
7b939e0468 24-keepalive-holepunch
Added the ability to hole punch NAT
2023-11-21 20:42:43 +00:00
6e201ebaf5 24-keepalive-holepunch
Nodes acting as peers and nodes acting as clients
2023-11-21 16:42:49 +00:00
06542da03c main
Fixed problems with timestamp not updating
2023-11-21 13:31:34 +00:00
0d63cd6624 main
Adding words.txt for what words
2023-11-20 18:12:58 +00:00
f13319cfc1 Merge pull request #22 from tim-beatham/21-phonetic-words-ipv6
21 phonetic words ipv6
2023-11-20 18:08:49 +00:00
95f4495b0b 21-phonetic-words-ipv6
Simple what 8 words implementation
2023-11-20 18:07:52 +00:00
330fa74ef4 IPv6 What 8 Words
what 8 words for ipv6 started
2023-11-20 15:22:32 +00:00
3e5b57e41f Merge pull request #20 from tim-beatham/19-hash-wg-interface
Hashing the WireGuard interface
2023-11-20 13:04:19 +00:00
b179cd3cf4 Hashing the WireGuard interface
Hashing the interface and using ephmeral ports so that the admin doesn't
choose an interface and port combination. An administrator can alteranatively
decide to provide port but this isn't critical.
2023-11-20 13:03:42 +00:00
8f211aa116 Merge pull request #18 from tim-beatham/26-performance-testing
Stubbing out WireGuard components
2023-11-20 11:29:37 +00:00
388153e706 Stubbing out WireGuard components
Stubbing our WireGuard components so that I can use docker/podman
network_mode=host. This is much more efficient than the docker/podman
userspace network.
2023-11-20 11:28:12 +00:00
023565d985 Merge pull request #17 from tim-beatham/25-ability-to-aliases
25 ability to aliases
2023-11-17 22:20:57 +00:00
36c264b38e 25-ability-aliases
Fixed unit tests failing
2023-11-17 22:18:53 +00:00
68db795f47 Ability to specify aliases
Ability to specify aliases that automatically append to /etc/hosts
2023-11-17 22:13:51 +00:00
f6160fe138 Adding aliases that automatically gets added 2023-11-17 19:13:20 +00:00
2c5289afb0 Merge pull request #16 from tim-beatham/15-add-rest-api
Developed a rest API
2023-11-15 12:57:05 +00:00
7199d07a76 Added smegmesh submodule 2023-11-13 10:46:52 +00:00
5f176e731f Developed a rest API 2023-11-13 10:44:14 +00:00
44f119b45c Updating examples 2023-11-08 09:19:24 +00:00
5215d5d54d Merge pull request #14 from tim-beatham/13-netlink-api
Removed interface manipulation via os.Exec into
2023-11-07 19:53:39 +00:00
57 changed files with 2581 additions and 656 deletions

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "smegmesh-web"]
path = smegmesh-web
url = git@github.com:tim-beatham/smegmesh-web.git

12
Containerfile Normal file
View File

@ -0,0 +1,12 @@
FROM docker.io/library/golang:bookworm
COPY ./ /wgmesh
RUN apt-get update && apt-get install -y \
wireguard \
wireguard-tools \
iproute2 \
iputils-ping \
tmux \
vim
WORKDIR /wgmesh
RUN go mod tidy
RUN go build -o /usr/local/bin ./...

1
Dockerfile Symbolic link
View File

@ -0,0 +1 @@
Containerfile

19
cmd/api/main.go Normal file
View File

@ -0,0 +1,19 @@
package main
import (
"log"
"github.com/tim-beatham/wgmesh/pkg/api"
)
func main() {
apiServer, err := api.NewSmegServer(api.ApiServerConf{
WordsFile: "./cmd/api/words.txt",
})
if err != nil {
log.Fatal(err.Error())
}
apiServer.Run(":8080")
}

257
cmd/api/words.txt Normal file
View File

@ -0,0 +1,257 @@
be
to
of
it
in
we
do
he
on
go
at
if
or
up
by
hi
the
and
you
not
for
but
say
get
she
one
all
can
out
who
now
see
way
how
lot
yes
use
any
day
try
put
let
why
new
off
big
too
ask
man
bit
end
may
own
run
pay
job
old
kid
bad
few
ago
far
buy
set
guy
car
sit
war
win
yet
top
law
cut
low
die
eat
age
hit
air
add
boy
act
tax
oil
eye
son
key
fun
dad
dog
arm
fly
box
gas
lie
hot
gun
per
art
red
fit
bed
fan
mix
mom
sex
bus
fix
bar
lay
ice
bet
bag
due
aid
tie
leg
ban
odd
cup
dry
cry
rid
pop
sir
cat
map
sad
sea
aim
sun
fat
row
egg
tea
god
wed
tip
ear
hat
net
ill
dig
fee
mad
gap
nor
bid
era
toy
sky
bin
owe
wet
tap
pro
ski
cow
pen
van
web
pot
sum
cap
log
pub
pig
joy
raw
rat
via
lip
two
six
ten
lab
ton
mid
bat
hip
gut
sin
non
rub
sub
par
pre
ray
cue
dye
fin
ion
neo
hey
wow
mum
bye
aye
jet
sue
pet
flu
cop
ooh
rip
spy
pie
bug
gum
wan
rap
nut
beg
pin
pit
jam
tag
fax
vet
fry
pad
lad
mud
bay
con
pan
gee
toe
dip
shy
gym
zoo
fox
bow
tin
hop
wee
kit
opt
vow
sew
cab
bee
rob
rig
yep
ego
rib
nod
hug
lap
ash
hum
dam
bum
yen
jar

18
cmd/dns/main.go Normal file
View File

@ -0,0 +1,18 @@
package main
import (
"log"
smegdns "github.com/tim-beatham/wgmesh/pkg/dns"
)
func main() {
server, err := smegdns.NewDns(53)
if err != nil {
log.Fatal(err.Error())
}
defer server.Close()
server.Listen()
}

View File

@ -16,7 +16,6 @@ const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct { type CreateMeshParams struct {
Client *ipcRpc.Client Client *ipcRpc.Client
IfName string
WgPort int WgPort int
Endpoint string Endpoint string
} }
@ -24,7 +23,6 @@ type CreateMeshParams struct {
func createMesh(args *CreateMeshParams) string { func createMesh(args *CreateMeshParams) string {
var reply string var reply string
newMeshParams := ipc.NewMeshArgs{ newMeshParams := ipc.NewMeshArgs{
IfName: args.IfName,
WgPort: args.WgPort, WgPort: args.WgPort,
Endpoint: args.Endpoint, Endpoint: args.Endpoint,
} }
@ -68,7 +66,6 @@ func joinMesh(params *JoinMeshParams) string {
args := ipc.JoinMeshArgs{ args := ipc.JoinMeshArgs{
MeshId: params.MeshId, MeshId: params.MeshId,
IpAdress: params.IpAddress, IpAdress: params.IpAddress,
IfName: params.IfName,
Port: params.WgPort, Port: params.WgPort,
} }
@ -171,6 +168,68 @@ func putDescription(client *ipcRpc.Client, description string) {
fmt.Println(reply) fmt.Println(reply)
} }
// putAlias: puts an alias for the node
func putAlias(client *ipcRpc.Client, alias string) {
var reply string
err := client.Call("IpcHandler.PutAlias", &alias, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func setService(client *ipcRpc.Client, service, value string) {
var reply string
serviceArgs := &ipc.PutServiceArgs{
Service: service,
Value: value,
}
err := client.Call("IpcHandler.PutService", serviceArgs, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func deleteService(client *ipcRpc.Client, service string) {
var reply string
err := client.Call("IpcHandler.PutService", &service, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getNode(client *ipcRpc.Client, nodeId, meshId string) {
var reply string
args := &ipc.GetNodeArgs{
NodeId: nodeId,
MeshId: meshId,
}
err := client.Call("IpcHandler.GetNode", &args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() { func main() {
parser := argparse.NewParser("wg-mesh", parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard meshes") "wg-mesh Manipulate WireGuard meshes")
@ -178,25 +237,24 @@ func main() {
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh") newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to") listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network") joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network")
// getMeshCmd := parser.NewCommand("get-mesh", "Get a mesh network")
enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface") enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format") getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network") leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath") queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node") putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{}) var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true}) var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{}) var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
// var getMeshId *string = getMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true}) var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -208,6 +266,16 @@ func main() {
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true}) var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{Required: true})
var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{Required: true})
var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{Required: true})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true})
var getNodeNodeId *string = getNodeCmd.String("n", "nodeid", &argparse.Options{Required: true})
var getNodeMeshId *string = getNodeCmd.String("m", "meshid", &argparse.Options{Required: true})
err := parser.Parse(os.Args) err := parser.Parse(os.Args)
if err != nil { if err != nil {
@ -224,7 +292,6 @@ func main() {
if newMeshCmd.Happened() { if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{ fmt.Println(createMesh(&CreateMeshParams{
Client: client, Client: client,
IfName: *newMeshIfName,
WgPort: *newMeshPort, WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint, Endpoint: *newMeshEndpoint,
})) }))
@ -237,7 +304,6 @@ func main() {
if joinMeshCmd.Happened() { if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{ fmt.Println(joinMesh(&JoinMeshParams{
Client: client, Client: client,
IfName: *joinMeshIfName,
WgPort: *joinMeshPort, WgPort: *joinMeshPort,
IpAddress: *joinMeshIpAddress, IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId, MeshId: *joinMeshId,
@ -245,10 +311,6 @@ func main() {
})) }))
} }
// if getMeshCmd.Happened() {
// getMesh(client, *getMeshId)
// }
if getGraphCmd.Happened() { if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId) getGraph(client, *getGraphMeshId)
} }
@ -268,4 +330,20 @@ func main() {
if putDescriptionCmd.Happened() { if putDescriptionCmd.Happened() {
putDescription(client, *description) putDescription(client, *description)
} }
if putAliasCmd.Happened() {
putAlias(client, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey)
}
if getNodeCmd.Happened() {
getNode(client, *getNodeNodeId, *getNodeMeshId)
}
} }

View File

@ -11,4 +11,4 @@ interClusterChance: 0.15
branchRate: 3 branchRate: 3
infectionCount: 3 infectionCount: 3
keepAliveTime: 10 keepAliveTime: 10
pruneTime: 20 pruneTime: 20

View File

@ -1,7 +1,8 @@
package main package main
import ( import (
"log" "net/http"
_ "net/http/pprof"
"os" "os"
"os/signal" "os/signal"
@ -12,7 +13,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/robin" "github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync" "github.com/tim-beatham/wgmesh/pkg/sync"
"github.com/tim-beatham/wgmesh/pkg/timestamp" timer "github.com/tim-beatham/wgmesh/pkg/timers"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
) )
@ -35,6 +36,12 @@ func main() {
return return
} }
if conf.Profile {
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
}
var robinRpc robin.WgRpc var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl var syncProvider sync.SyncServiceImpl
@ -45,13 +52,14 @@ func main() {
SyncProvider: &syncProvider, SyncProvider: &syncProvider,
Client: client, Client: client,
} }
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
syncProvider.Server = ctrlServer syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer) syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester) syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester)
timestampScheduler := timestamp.NewTimestampScheduler(ctrlServer) timestampScheduler := timer.NewTimestampScheduler(ctrlServer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf) pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
routeScheduler := timer.NewRouteScheduler(ctrlServer)
robinIpcParams := robin.RobinIpcParams{ robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer, CtrlServer: ctrlServer,
@ -65,12 +73,13 @@ func main() {
return return
} }
log.Println("Running IPC Handler") logging.Log.WriteInfof("Running IPC Handler")
go ipc.RunIpcHandler(&robinIpc) go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run() go syncScheduler.Run()
go timestampScheduler.Run() go timestampScheduler.Run()
go pruneScheduler.Run() go pruneScheduler.Run()
go routeScheduler.Run()
closeResources := func() { closeResources := func() {
logging.Log.WriteInfof("Closing resources") logging.Log.WriteInfof("Closing resources")

View File

@ -0,0 +1,95 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
net-2:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.155.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-4:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
sysctls:
- net.ipv6.conf.all.forwarding=1
networks:
- net-1
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-5:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-6:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-7:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -0,0 +1,14 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

View File

@ -0,0 +1,42 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -0,0 +1,14 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

24
go.mod
View File

@ -5,9 +5,12 @@ go 1.21.3
require ( require (
github.com/akamensky/argparse v1.4.0 github.com/akamensky/argparse v1.4.0
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9 github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0 github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5
github.com/sirupsen/logrus v1.9.3 github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
google.golang.org/grpc v1.58.1 google.golang.org/grpc v1.58.1
google.golang.org/protobuf v1.31.0 google.golang.org/protobuf v1.31.0
@ -15,18 +18,33 @@ require (
) )
require ( require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect github.com/google/go-cmp v0.5.9 // indirect
github.com/josharian/native v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect
github.com/jsimonetti/rtnetlink v1.3.5 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.0 // indirect github.com/mdlayher/socket v0.5.0 // indirect
github.com/vishvananda/netns v0.0.0-20191106174202-0a2b9b5464df // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.13.0 // indirect golang.org/x/crypto v0.13.0 // indirect
golang.org/x/net v0.15.0 // indirect golang.org/x/net v0.15.0 // indirect
golang.org/x/sync v0.3.0 // indirect golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect

238
pkg/api/apiserver.go Normal file
View File

@ -0,0 +1,238 @@
package api
import (
"fmt"
"net/http"
ipcRpc "net/rpc"
"github.com/gin-gonic/gin"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/what8words"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type ApiServer interface {
GetMeshes(c *gin.Context)
Run(addr string) error
}
type SmegServer struct {
router *gin.Engine
client *ipcRpc.Client
words *what8words.What8Words
}
func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
routes := make([]Route, len(meshNode.Routes))
for index, route := range meshNode.Routes {
word, err := s.words.Convert(route)
if err != nil {
fmt.Println(err.Error())
}
routes[index] = Route{
Prefix: route,
RouteId: word,
}
}
return routes
}
func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil {
meshNode.Routes = make([]string, 0)
}
alias := meshNode.Alias
if alias == "" {
alias, _ = s.words.ConvertIdentifier(meshNode.WgHost)
}
return &SmegNode{
WgHost: meshNode.WgHost,
WgEndpoint: meshNode.WgEndpoint,
Endpoint: meshNode.HostEndpoint,
Timestamp: int(meshNode.Timestamp),
Description: meshNode.Description,
Routes: s.routeToApiRoute(meshNode),
PublicKey: meshNode.PublicKey,
Alias: alias,
Services: meshNode.Services,
}
}
func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
var smegMesh SmegMesh
smegMesh.MeshId = meshId
smegMesh.Nodes = make(map[string]SmegNode)
for _, node := range nodes {
smegMesh.Nodes[node.WgHost] = *s.meshNodeToAPIMeshNode(node)
}
return smegMesh
}
// CreateMesh: creates a mesh network
func (s *SmegServer) CreateMesh(c *gin.Context) {
var createMesh CreateMeshRequest
if err := c.ShouldBindJSON(&createMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.NewMeshArgs{
WgPort: createMesh.WgPort,
}
var reply string
err := s.client.Call("IpcHandler.CreateMesh", &ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, &gin.H{
"meshid": reply,
})
}
// JoinMesh: joins a mesh network
func (s *SmegServer) JoinMesh(c *gin.Context) {
var joinMesh JoinMeshRequest
if err := c.ShouldBindJSON(&joinMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap,
Port: joinMesh.WgPort,
}
var reply string
err := s.client.Call("IpcHandler.JoinMesh", &ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, &gin.H{
"status": "success",
})
}
// GetMesh: given a meshId returns the corresponding mesh
// network.
func (s *SmegServer) GetMesh(c *gin.Context) {
meshidParam := c.Param("meshid")
var meshid string = meshidParam
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &meshid, &getMeshReply)
if err != nil {
c.JSON(http.StatusNotFound,
&gin.H{
"error": fmt.Sprintf("could not find mesh %s", meshidParam),
})
return
}
mesh := s.meshToAPIMesh(meshidParam, getMeshReply.Nodes)
c.JSON(http.StatusOK, mesh)
}
func (s *SmegServer) GetMeshes(c *gin.Context) {
listMeshesReply := new(ipc.ListMeshReply)
err := s.client.Call("IpcHandler.ListMeshes", "", &listMeshesReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes := make([]SmegMesh, 0)
for _, mesh := range listMeshesReply.Meshes {
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &mesh, &getMeshReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes = append(meshes, s.meshToAPIMesh(mesh, getMeshReply.Nodes))
}
c.JSON(http.StatusOK, meshes)
}
func (s *SmegServer) Run(addr string) error {
logging.Log.WriteInfof("Running API server")
return s.router.Run(addr)
}
func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
words, err := what8words.NewWhat8Words(conf.WordsFile)
if err != nil {
return nil, err
}
router := gin.Default()
router.Use(gin.LoggerWithConfig(gin.LoggerConfig{
Output: logging.Log.Writer(),
}))
smegServer := &SmegServer{
router: router,
client: client,
words: words,
}
router.GET("/meshes", smegServer.GetMeshes)
router.GET("/mesh/:meshid", smegServer.GetMesh)
router.POST("/mesh/create", smegServer.CreateMesh)
router.POST("/mesh/join", smegServer.JoinMesh)
return smegServer, nil
}

37
pkg/api/types.go Normal file
View File

@ -0,0 +1,37 @@
package api
type Route struct {
RouteId string `json:"routeId"`
Prefix string `json:"prefix"`
}
type SmegNode struct {
Alias string `json:"alias"`
WgHost string `json:"wgHost"`
WgEndpoint string `json:"wgEndpoint"`
Endpoint string `json:"endpoint"`
Timestamp int `json:"timestamp"`
Description string `json:"description"`
PublicKey string `json:"publicKey"`
Routes []Route `json:"routes"`
Services map[string]string `json:"services"`
}
type SmegMesh struct {
MeshId string `json:"meshid"`
Nodes map[string]SmegNode `json:"nodes"`
}
type CreateMeshRequest struct {
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
}
type JoinMeshRequest struct {
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"`
}
type ApiServerConf struct {
WordsFile string
}

View File

@ -18,13 +18,14 @@ import (
// CrdtMeshManager manages nodes in the crdt mesh // CrdtMeshManager manages nodes in the crdt mesh
type CrdtMeshManager struct { type CrdtMeshManager struct {
MeshId string MeshId string
IfName string IfName string
NodeId string Client *wgctrl.Client
Client *wgctrl.Client doc *automerge.Doc
doc *automerge.Doc LastHash automerge.ChangeHash
LastHash automerge.ChangeHash conf *conf.WgMeshConfiguration
conf *conf.WgMeshConfiguration cache *MeshCrdt
lastCacheHash automerge.ChangeHash
} }
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) { func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
@ -34,15 +35,78 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
panic("node must be of type *MeshNodeCrdt") panic("node must be of type *MeshNodeCrdt")
} }
crdt.Routes = make(map[string]interface{}) crdt.Routes = make(map[string]Route)
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix() crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
}
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
nodeType, err := node.Map().Get("type")
if err != nil || nodeType.Kind() != automerge.KindStr {
return false
}
return nodeType.Str() == string(conf.PEER_ROLE)
}
// isAlive: checks that the node's configuration has been updated
// since the rquired keep alive time
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
timestamp, err := node.Map().Get("timestamp")
if err != nil || timestamp.Kind() != automerge.KindInt64 {
return false
}
keepAliveTime := timestamp.Int64()
return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
}
func (c *CrdtMeshManager) GetPeers() []string {
keys, _ := c.doc.Path("nodes").Map().Keys()
keys = lib.Filter(keys, func(publicKey string) bool {
return c.isPeer(publicKey) && c.isAlive(publicKey)
})
return keys
} }
// GetMesh(): Converts the document into a struct // GetMesh(): Converts the document into a struct
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) { func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return automerge.As[*MeshCrdt](c.doc.Root()) changes, err := c.doc.Changes(c.lastCacheHash)
if err != nil {
return nil, err
}
if c.cache == nil || len(changes) > 0 {
c.lastCacheHash = c.LastHash
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
if err != nil {
return nil, err
}
c.cache = cache
}
return c.cache, nil
} }
// GetMeshId returns the meshid of the mesh // GetMeshId returns the meshid of the mesh
@ -82,13 +146,23 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
manager.IfName = params.DevName manager.IfName = params.DevName
manager.Client = params.Client manager.Client = params.Client
manager.conf = &params.Conf manager.conf = &params.Conf
manager.cache = nil
return &manager, nil return &manager, nil
} }
// GetNode: returns a mesh node crdt.Close releases resources used by a Client. // NodeExists: returns true if the node exists. Returns false
func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) { func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err == nil
}
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint) node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap {
return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type")
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -178,8 +252,74 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return err return err
} }
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if node.Kind() != automerge.KindMap {
return fmt.Errorf("%s does not exist", nodeId)
}
err = node.Map().Set("alias", alias)
if err == nil {
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
}
return err
}
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: services property does not exist in node")
}
return service.Map().Set(key, value)
}
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("RemoveService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("services property does not exist")
}
err = service.Map().Delete(key)
if err != nil {
return fmt.Errorf("service %s does not exist", key)
}
return nil
}
// AddRoutes: adds routes to the specific nodeId // AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error { func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
logging.Log.WriteInfof("Adding route to %s", nodeId) logging.Log.WriteInfof("Adding route to %s", nodeId)
@ -198,7 +338,10 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
} }
for _, route := range routes { for _, route := range routes {
err = routeMap.Map().Set(route, struct{}{}) err = routeMap.Map().Set(route.GetDestination().String(), Route{
Destination: route.GetDestination().String(),
HopCount: int64(route.GetHopCount()),
})
if err != nil { if err != nil {
return err return err
@ -207,6 +350,66 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
return nil return nil
} }
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return nil, err
}
if nodeVal.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node does not exist")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return nil, err
}
if routeMap.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node %s is not a map", nodeId)
}
routes, err := automerge.As[map[string]Route](routeMap)
return lib.MapValues(routes), err
}
func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
if err != nil {
return nil, err
}
routes := make(map[string]mesh.Route)
for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route
}
for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node)
if err != nil {
return nil, err
}
for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()]
if !ok || route.GetHopCount() < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(),
HopCount: int64(route.GetHopCount()) + 1,
}
}
}
}
return routes, nil
}
// DeleteRoutes deletes the specified routes // DeleteRoutes deletes the specified routes
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error { func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId) nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
@ -319,8 +522,10 @@ func (m *MeshNodeCrdt) GetTimeStamp() int64 {
return m.Timestamp return m.Timestamp
} }
func (m *MeshNodeCrdt) GetRoutes() []string { func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
return lib.MapKeys(m.Routes) return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route {
return &r
})
} }
func (m *MeshNodeCrdt) GetDescription() string { func (m *MeshNodeCrdt) GetDescription() string {
@ -336,6 +541,26 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
return strings.Join(constituents, ":") return strings.Join(constituents, ":")
} }
func (m *MeshNodeCrdt) GetAlias() string {
return m.Alias
}
func (m *MeshNodeCrdt) GetServices() map[string]string {
services := make(map[string]string)
for key, service := range m.Services {
services[key] = service
}
return services
}
// GetType refers to the type of the node. Peer means that the node is globally accessible
// Client means the node is only accessible through another peer
func (n *MeshNodeCrdt) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode { func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode) nodes := make(map[string]mesh.MeshNode)
@ -348,8 +573,20 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
Timestamp: node.Timestamp, Timestamp: node.Timestamp,
Routes: node.Routes, Routes: node.Routes,
Description: node.Description, Description: node.Description,
Alias: node.Alias,
Services: node.GetServices(),
Type: node.Type,
} }
} }
return nodes return nodes
} }
func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet
}
func (r *Route) GetHopCount() int {
return int(r.HopCount)
}

View File

@ -28,14 +28,23 @@ type MeshNodeFactory struct {
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode { func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params) hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
if f.Config.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
}
return &MeshNodeCrdt{ return &MeshNodeCrdt{
HostEndpoint: fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort), HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(), PublicKey: params.PublicKey.String(),
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort), WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort),
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()), WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
// Always set the routes as empty. // Always set the routes as empty.
// Routes handled by external component // Routes handled by external component
Routes: map[string]interface{}{}, Routes: make(map[string]Route),
Description: "",
Alias: "",
Type: string(f.Config.Role),
} }
} }
@ -48,7 +57,19 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
} else if len(f.Config.Endpoint) != 0 { } else if len(f.Config.Endpoint) != 0 {
hostName = f.Config.Endpoint hostName = f.Config.Endpoint
} else { } else {
hostName = lib.GetOutboundIP().String() ipFunc := lib.GetPublicIP
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}
ip, err := ipFunc()
if err != nil {
return ""
}
hostName = ip.String()
} }
return hostName return hostName

View File

@ -1,14 +1,23 @@
package crdt package crdt
// Route: Represents a CRDT of the given route
type Route struct {
Destination string `automerge:"destination"`
HopCount int64 `automerge:"hopCount"`
}
// MeshNodeCrdt: Represents a CRDT for a mesh nodes // MeshNodeCrdt: Represents a CRDT for a mesh nodes
type MeshNodeCrdt struct { type MeshNodeCrdt struct {
HostEndpoint string `automerge:"hostEndpoint"` HostEndpoint string `automerge:"hostEndpoint"`
WgEndpoint string `automerge:"wgEndpoint"` WgEndpoint string `automerge:"wgEndpoint"`
PublicKey string `automerge:"publicKey"` PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"` WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"` Timestamp int64 `automerge:"timestamp"`
Routes map[string]interface{} `automerge:"routes"` Routes map[string]Route `automerge:"routes"`
Description string `automerge:"description"` Alias string `automerge:"alias"`
Description string `automerge:"description"`
Services map[string]string `automerge:"services"`
Type string `automerge:"type"`
} }
// MeshCrdt: Represents the mesh network as a whole // MeshCrdt: Represents the mesh network as a whole

View File

@ -16,6 +16,20 @@ func (m *WgMeshConfigurationError) Error() string {
return m.msg return m.msg
} }
type NodeType string
const (
PEER_ROLE NodeType = "peer"
CLIENT_ROLE NodeType = "client"
)
type IPDiscovery string
const (
PUBLIC_IP_DISCOVERY = "public"
DNS_IP_DISCOVERY = "dns"
)
type WgMeshConfiguration struct { type WgMeshConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS // CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath"` CertificatePath string `yaml:"certificatePath"`
@ -28,6 +42,9 @@ type WgMeshConfiguration struct {
SkipCertVerification bool `yaml:"skipCertVerification"` SkipCertVerification bool `yaml:"skipCertVerification"`
// Port to run the GrpcServer on // Port to run the GrpcServer on
GrpcPort string `yaml:"gRPCPort"` GrpcPort string `yaml:"gRPCPort"`
// IPDIscovery: how to discover your IP if not specified. Use DNS server 8.8.8.8 or
// use public IP discovery library
IPDiscovery IPDiscovery `yaml:"ipDiscovery"`
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes // AdvertiseRoutes advertises other meshes if the node is in multiple meshes
AdvertiseRoutes bool `yaml:"advertiseRoutes"` AdvertiseRoutes bool `yaml:"advertiseRoutes"`
// Endpoint is the IP in which this computer is publicly reachable. // Endpoint is the IP in which this computer is publicly reachable.
@ -47,8 +64,21 @@ type WgMeshConfiguration struct {
KeepAliveTime int `yaml:"keepAliveTime"` KeepAliveTime int `yaml:"keepAliveTime"`
// Timeout number of seconds before we consider the node as dead // Timeout number of seconds before we consider the node as dead
Timeout int `yaml:"timeout"` Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we consider the 'node' as dead // PruneTime number of seconds before we remove nodes that are likely to be dead
PruneTime int `yaml:"pruneTime"` PruneTime int `yaml:"pruneTime"`
// DeadTime: number of seconds before we consider the node as dead and stop considering it
// when picking a random peer
DeadTime int `yaml:"deadTime"`
// Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"`
// Role specifies whether or not the user is globally accessible.
// If the user is globaly accessible they specify themselves as a client.
Role NodeType `yaml:"role"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg int `yaml:"keepAliveWg"`
} }
func ValidateConfiguration(c *WgMeshConfiguration) error { func ValidateConfiguration(c *WgMeshConfiguration) error {
@ -118,9 +148,15 @@ func ValidateConfiguration(c *WgMeshConfiguration) error {
} }
} }
if c.PruneTime <= 1 { if c.PruneTime < 1 {
return &WgMeshConfigurationError{ return &WgMeshConfigurationError{
msg: "Prune time cannot be <= 1", msg: "Prune time cannot be < 1",
}
}
if c.DeadTime < 1 {
return &WgMeshConfigurationError{
msg: "Dead time cannot be < 1",
} }
} }
@ -130,6 +166,14 @@ func ValidateConfiguration(c *WgMeshConfiguration) error {
} }
} }
if c.Role == "" {
c.Role = PEER_ROLE
}
if c.IPDiscovery == "" {
c.IPDiscovery = PUBLIC_IP_DISCOVERY
}
return nil return nil
} }

View File

@ -31,11 +31,11 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
nodeFactory := crdt.MeshNodeFactory{ nodeFactory := crdt.MeshNodeFactory{
Config: *params.Conf, Config: *params.Conf,
} }
idGenerator := &lib.UUIDGenerator{} idGenerator := &lib.IDNameGenerator{}
ipAllocator := &ip.ULABuilder{} ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client) interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
configApplyer := mesh.NewWgMeshConfigApplyer() configApplyer := mesh.NewWgMeshConfigApplyer(params.Conf)
meshManagerParams := &mesh.NewMeshManagerParams{ meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf, Conf: *params.Conf,

View File

@ -17,6 +17,9 @@ type MeshNode struct {
WgHost string WgHost string
Timestamp int64 Timestamp int64
Routes []string Routes []string
Description string
Alias string
Services map[string]string
} }
// Represents a WireGuard Mesh // Represents a WireGuard Mesh

114
pkg/dns/dns.go Normal file
View File

@ -0,0 +1,114 @@
package smegdns
import (
"encoding/json"
"fmt"
"net"
"net/rpc"
"github.com/miekg/dns"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/query"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.`
type DNSHandler struct {
client *rpc.Client
server *dns.Server
}
// queryMesh: queries the mesh network for the given meshId and node
// with alias
func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
var reply string
err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{
MeshId: meshId,
Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias),
}, &reply)
if err != nil {
return nil
}
var node *query.QueryNode
err = json.Unmarshal([]byte(reply), &node)
if err != nil || node == nil {
return nil
}
ip, _, _ := net.ParseCIDR(node.WgHost)
return ip
}
func (d *DNSHandler) handleQuery(m *dns.Msg) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeAAAA:
logging.Log.WriteInfof("Query for %s", q.Name)
groups := lib.MatchCaptureGroup(MeshRegularExpression, q.Name)
if len(groups) == 0 {
continue
}
ip := d.queryMesh(groups["meshId"], groups["alias"])
rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip))
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
}
}
func (h *DNSHandler) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
msg := new(dns.Msg)
msg.SetReply(r)
msg.Authoritative = true
switch r.Opcode {
case dns.OpcodeQuery:
h.handleQuery(msg)
}
w.WriteMsg(msg)
}
func (h *DNSHandler) Listen() error {
return h.server.ListenAndServe()
}
func (h *DNSHandler) Close() error {
return h.server.Shutdown()
}
func NewDns(udpPort int) (*DNSHandler, error) {
client, err := rpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
dnsHander := DNSHandler{
client: client,
}
dns.HandleFunc("smeg.", dnsHander.handleDnsRequest)
dnsHander.server = &dns.Server{Addr: fmt.Sprintf(":%d", udpPort), Net: "udp"}
return &dnsHander, nil
}

View File

@ -4,13 +4,13 @@ package rpctypes;
option go_package = "pkg/rpc"; option go_package = "pkg/rpc";
service MeshCtrlServer { service MeshCtrlServer {
rpc JoinMesh(JoinMeshRequest) returns (JoinMeshReply) {} rpc GetMesh(GetMeshRequest) returns (GetMeshReply) {}
} }
message JoinMeshRequest { message GetMeshRequest {
string meshId = 2; string meshId = 1;
} }
message JoinMeshReply { message GetMeshReply {
bool success = 1; bytes mesh = 1;
} }

132
pkg/hosts/hosts.go Normal file
View File

@ -0,0 +1,132 @@
// hosts: utility for modifying the /etc/hosts file
package hosts
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"os"
"strings"
)
// HOSTS_FILE is the hosts file location
const HOSTS_FILE = "/etc/hosts"
const DOMAIN_HEADER = "#WG AUTO GENERATED HOSTS"
const DOMAIN_TRAILER = "#WG AUTO GENERATED HOSTS END"
type HostsEntry struct {
Alias string
Ip net.IP
}
// Generic interface to manipulate /etc/hosts file
type HostsManipulator interface {
// AddrAddr associates an aliasd with a given IP address
AddAddr(hosts ...HostsEntry)
// Remove deletes the entry from /etc/hosts
Remove(hosts ...HostsEntry)
// Writes the changes to /etc/hosts file
Write() error
}
type HostsManipulatorImpl struct {
hosts map[string]HostsEntry
}
// AddAddr implements HostsManipulator.
func (m *HostsManipulatorImpl) AddAddr(hosts ...HostsEntry) {
changed := false
for _, host := range hosts {
prev, ok := m.hosts[host.Ip.String()]
if !ok || prev.Alias != host.Alias {
changed = true
}
m.hosts[host.Ip.String()] = host
}
if changed {
m.Write()
}
}
// Remove implements HostsManipulator.
func (m *HostsManipulatorImpl) Remove(hosts ...HostsEntry) {
lenBefore := len(m.hosts)
for _, host := range hosts {
delete(m.hosts, host.Alias)
}
if lenBefore != len(m.hosts) {
m.Write()
}
}
func (m *HostsManipulatorImpl) removeHosts() string {
hostsFile, err := os.ReadFile(HOSTS_FILE)
if err != nil {
return ""
}
var contents strings.Builder
scanner := bufio.NewScanner(bytes.NewReader(hostsFile))
hostsSection := false
for scanner.Scan() {
line := scanner.Text()
if err == io.EOF {
break
} else if err != nil {
return ""
}
if !hostsSection && strings.Contains(line, DOMAIN_HEADER) {
hostsSection = true
}
if !hostsSection {
contents.WriteString(line + "\n")
}
if hostsSection && strings.Contains(line, DOMAIN_TRAILER) {
hostsSection = false
}
}
if scanner.Err() != nil && scanner.Err() != io.EOF {
return ""
}
return contents.String()
}
// Write implements HostsManipulator
func (m *HostsManipulatorImpl) Write() error {
contents := m.removeHosts()
var nextHosts strings.Builder
nextHosts.WriteString(contents)
nextHosts.WriteString(DOMAIN_HEADER + "\n")
for _, host := range m.hosts {
nextHosts.WriteString(fmt.Sprintf("%s\t%s\n", host.Ip.String(), host.Alias))
}
nextHosts.WriteString(DOMAIN_TRAILER + "\n")
return os.WriteFile(HOSTS_FILE, []byte(nextHosts.String()), 0644)
}
func NewHostsManipulator() HostsManipulator {
return &HostsManipulatorImpl{hosts: make(map[string]HostsEntry)}
}

View File

@ -11,8 +11,6 @@ import (
) )
type NewMeshArgs struct { type NewMeshArgs struct {
// IfName is the interface that the mesh instance will run on
IfName string
// WgPort is the WireGuard port to expose // WgPort is the WireGuard port to expose
WgPort int WgPort int
// Endpoint is the routable alias of the machine. Can be an IP // Endpoint is the routable alias of the machine. Can be an IP
@ -25,13 +23,19 @@ type JoinMeshArgs struct {
MeshId string MeshId string
// IpAddress is a routable IP in another mesh // IpAddress is a routable IP in another mesh
IpAdress string IpAdress string
// IfName is the interface name of the mesh
IfName string
// Port is the WireGuard port to expose // Port is the WireGuard port to expose
Port int Port int
// Endpoint is the routable address of this machine. If not provided // Endpoint is the routable address of this machine. If not provided
// defaults to the default address // defaults to the default address
Endpoint string Endpoint string
// Client specifies whether we should join as a client of the peer
// we are connecting to
Client bool
}
type PutServiceArgs struct {
Service string
Value string
} }
type GetMeshReply struct { type GetMeshReply struct {
@ -47,23 +51,31 @@ type QueryMesh struct {
Query string Query string
} }
type GetNodeArgs struct {
NodeId string
MeshId string
}
type MeshIpc interface { type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error ListMeshes(name string, reply *ListMeshReply) error
JoinMesh(args JoinMeshArgs, reply *string) error JoinMesh(args JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error GetMesh(meshId string, reply *GetMeshReply) error
EnableInterface(meshId string, reply *string) error
GetDOT(meshId string, reply *string) error GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error PutDescription(description string, reply *string) error
PutAlias(alias string, reply *string) error
PutService(args PutServiceArgs, reply *string) error
GetNode(args GetNodeArgs, reply *string) error
DeleteService(service string, reply *string) error
} }
const SockAddr = "/tmp/wgmesh_ipc.sock" const SockAddr = "/tmp/wgmesh_ipc.sock"
func RunIpcHandler(server MeshIpc) error { func RunIpcHandler(server MeshIpc) error {
if err := os.RemoveAll(SockAddr); err != nil { if err := os.RemoveAll(SockAddr); err != nil {
return errors.New("Could not find to address") return errors.New("could not find to address")
} }
rpc.Register(server) rpc.Register(server)

46
pkg/lib/hashing.go Normal file
View File

@ -0,0 +1,46 @@
package lib
import (
"hash/fnv"
"sort"
)
type consistentHashRecord[V any] struct {
record V
value int
}
func HashString(value string) int {
f := fnv.New32a()
f.Write([]byte(value))
return int(f.Sum32())
}
// ConsistentHash implementation. Traverse the values until we find a key
// less than ours.
func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int, keyFunc func(K) int) V {
if len(values) == 0 {
panic("values is empty")
}
vs := Map(values, func(v V) consistentHashRecord[V] {
return consistentHashRecord[V]{
v,
bucketFunc(v),
}
})
sort.SliceStable(vs, func(i, j int) bool {
return vs[i].value < vs[j].value
})
ourKey := keyFunc(client)
for _, record := range vs {
if ourKey < record.value {
return record.record
}
}
return vs[0].record
}

View File

@ -1,6 +1,9 @@
package lib package lib
import "github.com/google/uuid" import (
"github.com/anandvarma/namegen"
"github.com/google/uuid"
)
// IdGenerator generates unique ids // IdGenerator generates unique ids
type IdGenerator interface { type IdGenerator interface {
@ -15,3 +18,11 @@ func (g *UUIDGenerator) GetId() (string, error) {
id := uuid.New() id := uuid.New()
return id.String(), nil return id.String(), nil
} }
type IDNameGenerator struct {
}
func (i *IDNameGenerator) GetId() (string, error) {
name_schema := namegen.New()
return name_schema.Get(), nil
}

View File

@ -1,17 +1,61 @@
package lib package lib
import ( import (
"encoding/json"
"io"
"log" "log"
"net" "net"
"net/http"
) )
// GetOutboundIP: gets the oubound IP of this packet // GetOutboundIP: gets the oubound IP of this packet
func GetOutboundIP() net.IP { func GetOutboundIP() (net.IP, error) {
conn, err := net.Dial("udp", "8.8.8.8:80") conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
defer conn.Close() defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr) localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP return localAddr.IP, nil
}
const IP_SERVICE = "https://api.ipify.org?format=json"
type IpResponse struct {
Ip string `json:"ip"`
}
func (i *IpResponse) GetIP() net.IP {
return net.ParseIP(i.Ip)
}
// GetPublicIP: get the nodes public IP address. For when a node is behind NAT
func GetPublicIP() (net.IP, error) {
req, err := http.NewRequest(http.MethodGet, IP_SERVICE, nil)
if err != nil {
return nil, err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
resBody, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
var jsonResponse IpResponse
err = json.Unmarshal([]byte(resBody), &jsonResponse)
if err != nil {
return nil, err
}
return jsonResponse.GetIP(), nil
} }

19
pkg/lib/regex.go Normal file
View File

@ -0,0 +1,19 @@
package lib
import "regexp"
func MatchCaptureGroup(pattern, payload string) map[string]string {
patterns := make(map[string]string)
expr := regexp.MustCompile(pattern)
match := expr.FindStringSubmatch(payload)
for i, name := range expr.SubexpNames() {
if i != 0 && name != "" {
patterns[name] = match[i]
}
}
return patterns
}

View File

@ -201,7 +201,7 @@ func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
}) })
if err != nil { if err != nil {
return fmt.Errorf("failed to delete route %w", err) return fmt.Errorf("failed to delete route %s", dst.IP.String())
} }
return nil return nil
@ -219,22 +219,15 @@ func (r1 Route) equal(r2 Route) bool {
// DeleteRoutes deletes all routes not in exclude // DeleteRoutes deletes all routes not in exclude
func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error { func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error {
routes := make([]rtnetlink.RouteMessage, 0) routes, err := c.listRoutes(ifName, family)
if len(exclude) != 0 { if err != nil {
lRoutes, err := c.listRoutes(ifName, family, exclude[0].Gateway) return err
if err != nil {
return err
}
routes = lRoutes
} }
ifRoutes := make([]Route, 0) ifRoutes := make([]Route, 0)
for _, rtRoute := range routes { for _, rtRoute := range routes {
logging.Log.WriteInfof("Routes: %s", rtRoute.Attributes.Dst.String())
maskSize := 128 maskSize := 128
if family == unix.AF_INET { if family == unix.AF_INET {
@ -262,7 +255,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
toDelete := Filter(ifRoutes, shouldExclude) toDelete := Filter(ifRoutes, shouldExclude)
for _, route := range toDelete { for _, route := range toDelete {
logging.Log.WriteInfof("Deleting route %s", route.Destination.String()) logging.Log.WriteInfof("Deleting route: %s", route.Gateway.String())
err := c.DeleteRoute(ifName, route) err := c.DeleteRoute(ifName, route)
if err != nil { if err != nil {
@ -274,7 +267,7 @@ func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...R
} }
// listRoutes lists all routes on the interface // listRoutes lists all routes on the interface
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP) ([]rtnetlink.RouteMessage, error) { func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.RouteMessage, error) {
iface, err := net.InterfaceByName(ifName) iface, err := net.InterfaceByName(ifName)
if err != nil { if err != nil {
@ -288,7 +281,7 @@ func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8, gateway net.IP
} }
filterFunc := func(r rtnetlink.RouteMessage) bool { filterFunc := func(r rtnetlink.RouteMessage) bool {
return r.Attributes.Gateway.Equal(gateway) && r.Attributes.OutIface == uint32(iface.Index) return r.Attributes.Gateway != nil && r.Attributes.OutIface == uint32(iface.Index)
} }
routes = Filter(routes, filterFunc) routes = Filter(routes, filterFunc)

View File

@ -2,6 +2,7 @@
package logging package logging
import ( import (
"io"
"os" "os"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -15,6 +16,7 @@ type Logger interface {
WriteInfof(msg string, args ...interface{}) WriteInfof(msg string, args ...interface{})
WriteErrorf(msg string, args ...interface{}) WriteErrorf(msg string, args ...interface{})
WriteWarnf(msg string, args ...interface{}) WriteWarnf(msg string, args ...interface{})
Writer() io.Writer
} }
type LogrusLogger struct { type LogrusLogger struct {
@ -33,6 +35,10 @@ func (l *LogrusLogger) WriteWarnf(msg string, args ...interface{}) {
l.logger.Warnf(msg, args...) l.logger.Warnf(msg, args...)
} }
func (l *LogrusLogger) Writer() io.Writer {
return l.logger.Writer()
}
func NewLogrusLogger() *LogrusLogger { func NewLogrusLogger() *LogrusLogger {
logger := logrus.New() logger := logrus.New()
logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true}) logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})

46
pkg/mesh/alias.go Normal file
View File

@ -0,0 +1,46 @@
package mesh
import (
"fmt"
"github.com/tim-beatham/wgmesh/pkg/hosts"
)
type MeshAliasManager interface {
AddAliases(nodes []MeshNode)
RemoveAliases(node []MeshNode)
}
type AliasManager struct {
hosts hosts.HostsManipulator
}
// AddAliases: on node update or change add aliases to the hosts file
func (a *AliasManager) AddAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.AddAddr(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
// RemoveAliases: on node remove remove aliases from the hosts file
func (a *AliasManager) RemoveAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.Remove(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
func NewAliasManager() MeshAliasManager {
return &AliasManager{
hosts: hosts.NewHostsManipulator(),
}
}

View File

@ -3,7 +3,13 @@ package mesh
import ( import (
"fmt" "fmt"
"net" "net"
"slices"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
@ -16,10 +22,24 @@ type MeshConfigApplyer interface {
// WgMeshConfigApplyer applies WireGuard configuration // WgMeshConfigApplyer applies WireGuard configuration
type WgMeshConfigApplyer struct { type WgMeshConfigApplyer struct {
meshManager MeshManager meshManager MeshManager
config *conf.WgMeshConfiguration
routeInstaller route.RouteInstaller
} }
func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) { type routeNode struct {
gateway string
route Route
}
func (r *routeNode) equals(route2 *routeNode) bool {
return r.gateway == route2.gateway && RouteEquals(r.route, route2.route)
}
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
peerToClients map[string][]net.IPNet,
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint()) endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
if err != nil { if err != nil {
@ -35,20 +55,88 @@ func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
allowedips := make([]net.IPNet, 1) allowedips := make([]net.IPNet, 1)
allowedips[0] = *node.GetWgHost() allowedips[0] = *node.GetWgHost()
clients, ok := peerToClients[node.GetWgHost().String()]
if ok {
allowedips = append(allowedips, clients...)
}
for _, route := range node.GetRoutes() { for _, route := range node.GetRoutes() {
_, ipnet, _ := net.ParseCIDR(route) bestRoutes := routes[route.GetDestination().String()]
allowedips = append(allowedips, *ipnet)
if len(bestRoutes) == 1 {
allowedips = append(allowedips, *route.GetDestination())
} else if len(bestRoutes) > 1 {
keyFunc := func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
return lib.HashString(pubKey.String())
}
bucketFunc := func(rn routeNode) int {
return lib.HashString(rn.gateway)
}
// Else there is more than one candidate so consistently hash
pickedRoute := lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc)
if pickedRoute.gateway == pubKey.String() {
allowedips = append(allowedips, *route.GetDestination())
}
}
}
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
existing := slices.IndexFunc(device.Peers, func(p wgtypes.Peer) bool {
pubKey, _ := node.GetPublicKey()
return p.PublicKey.String() == pubKey.String()
})
if existing != -1 {
endpoint = device.Peers[existing].Endpoint
} }
peerConfig := wgtypes.PeerConfig{ peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey, PublicKey: pubKey,
Endpoint: endpoint, Endpoint: endpoint,
AllowedIPs: allowedips, AllowedIPs: allowedips,
PersistentKeepaliveInterval: &keepAlive,
} }
return &peerConfig, nil return &peerConfig, nil
} }
// getRoutes: finds the routes with the least hop distance. If more than one route exists
// consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(mesh MeshSnapshot) map[string][]routeNode {
routes := make(map[string][]routeNode)
for _, node := range mesh.GetNodes() {
for _, route := range node.GetRoutes() {
destination := route.GetDestination().String()
otherRoute, ok := routes[destination]
pubKey, _ := node.GetPublicKey()
rn := routeNode{
gateway: pubKey.String(),
route: route,
}
if !ok {
otherRoute = make([]routeNode, 1)
otherRoute[0] = rn
routes[destination] = otherRoute
} else if otherRoute[0].route.GetHopCount() > route.GetHopCount() {
otherRoute[0] = rn
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
routes[destination] = append(otherRoute, rn)
}
}
}
return routes
}
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error { func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
snap, err := mesh.GetMesh() snap, err := mesh.GetMesh()
@ -56,18 +144,67 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err return err
} }
nodes := snap.GetNodes() nodes := lib.MapValues(snap.GetNodes())
peerConfigs := make([]wgtypes.PeerConfig, len(nodes)) peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
peers := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.PEER_ROLE
})
var count int = 0 var count int = 0
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
if err != nil {
return err
}
peerToClients := make(map[string][]net.IPNet)
routes := m.getRoutes(snap)
installedRoutes := make([]lib.Route, 0)
for _, n := range nodes { for _, n := range nodes {
peer, err := convertMeshNode(n) if NodeEquals(n, self) {
continue
}
if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE {
hashFunc := func(mn MeshNode) int {
return lib.HashString(mn.GetWgHost().String())
}
peer := lib.ConsistentHash(peers, n, hashFunc, hashFunc)
clients, ok := peerToClients[peer.GetWgHost().String()]
if !ok {
clients = make([]net.IPNet, 0)
peerToClients[peer.GetWgHost().String()] = clients
}
peerToClients[peer.GetWgHost().String()] = append(clients, *n.GetWgHost())
continue
}
dev, _ := mesh.GetDevice()
peer, err := m.convertMeshNode(n, dev, peerToClients, routes)
if err != nil { if err != nil {
return err return err
} }
for _, route := range peer.AllowedIPs {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
if !ipNet.Contains(route.IP) {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: n.GetWgHost().IP,
Destination: route,
})
}
}
peerConfigs[count] = *peer peerConfigs[count] = *peer
count++ count++
} }
@ -82,6 +219,12 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err return err
} }
err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
if err != nil {
return err
}
return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg) return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
} }
@ -112,7 +255,7 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{ m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
ReplacePeers: true, ReplacePeers: true,
Peers: make([]wgtypes.PeerConfig, 1), Peers: make([]wgtypes.PeerConfig, 0),
}) })
return nil return nil
@ -122,6 +265,9 @@ func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager m.meshManager = manager
} }
func NewWgMeshConfigApplyer() MeshConfigApplyer { func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer {
return &WgMeshConfigApplyer{} return &WgMeshConfigApplyer{
config: config,
routeInstaller: route.NewRouteInstaller(),
}
} }

View File

@ -61,7 +61,7 @@ func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId strin
self, _ := c.manager.GetSelf(meshId) self, _ := c.manager.GetSelf(meshId)
if node.GetHostEndpoint() == self.GetHostEndpoint() { if NodeEquals(self, node) {
return return
} }

View File

@ -14,22 +14,27 @@ import (
) )
type MeshManager interface { type MeshManager interface {
CreateMesh(devName string, port int) (string, error) CreateMesh(port int) (string, error)
AddMesh(params *AddMeshParams) error AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider GetMesh(meshId string) MeshProvider
EnableInterface(meshId string) error
GetPublicKey(meshId string) (*wgtypes.Key, error) GetPublicKey(meshId string) (*wgtypes.Key, error)
AddSelf(params *AddSelfParams) error AddSelf(params *AddSelfParams) error
LeaveMesh(meshId string) error LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error) GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error ApplyConfig() error
SetDescription(description string) error SetDescription(description string) error
SetAlias(alias string) error
SetService(service string, value string) error
RemoveService(service string) error
UpdateTimeStamp() error UpdateTimeStamp() error
GetClient() *wgctrl.Client GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider GetMeshes() map[string]MeshProvider
Prune() error Prune() error
Close() error Close() error
GetMonitor() MeshMonitor
GetNode(string, string) MeshNode
GetRouteManager() RouteManager
} }
type MeshManagerImpl struct { type MeshManagerImpl struct {
@ -46,6 +51,59 @@ type MeshManagerImpl struct {
idGenerator lib.IdGenerator idGenerator lib.IdGenerator
ipAllocator ip.IPAllocator ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor
}
// GetRouteManager implements MeshManager.
func (m *MeshManagerImpl) GetRouteManager() RouteManager {
return m.RouteManager
}
// RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
if err != nil {
return err
}
}
return nil
}
// SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
if err != nil {
return err
}
}
return nil
}
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
mesh, ok := m.Meshes[meshid]
if !ok {
return nil
}
node, err := mesh.GetNode(nodeId)
if err != nil {
return nil
}
return node
}
// GetMonitor implements MeshManager.
func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
return m.Monitor
} }
// Prune implements MeshManager. // Prune implements MeshManager.
@ -62,15 +120,25 @@ func (m *MeshManagerImpl) Prune() error {
} }
// CreateMesh: Creates a new mesh, stores it and returns the mesh id // CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) { func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
meshId, err := m.idGenerator.GetId() meshId, err := m.idGenerator.GetId()
var ifName string = ""
if err != nil { if err != nil {
return "", err return "", err
} }
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(port, m.HostParameters.PrivateKey)
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: devName, DevName: ifName,
Port: port, Port: port,
Conf: m.conf, Conf: m.conf,
Client: m.Client, Client: m.Client,
@ -81,30 +149,31 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
return "", fmt.Errorf("error creating mesh: %w", err) return "", fmt.Errorf("error creating mesh: %w", err)
} }
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName,
Port: port,
})
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
m.Meshes[meshId] = nodeManager m.Meshes[meshId] = nodeManager
return meshId, nil return meshId, nil
} }
type AddMeshParams struct { type AddMeshParams struct {
MeshId string MeshId string
DevName string
WgPort int WgPort int
MeshBytes []byte MeshBytes []byte
} }
// AddMesh: Add the mesh to the list of meshes // AddMesh: Add the mesh to the list of meshes
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string
var err error
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey)
if err != nil {
return err
}
}
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: params.DevName, DevName: ifName,
Port: params.WgPort, Port: params.WgPort,
Conf: m.conf, Conf: m.conf,
Client: m.Client, Client: m.Client,
@ -122,11 +191,7 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
} }
m.Meshes[params.MeshId] = meshProvider m.Meshes[params.MeshId] = meshProvider
return nil
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName,
Port: params.WgPort,
})
} }
// HasChanges returns true if the mesh has changes // HasChanges returns true if the mesh has changes
@ -140,25 +205,13 @@ func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
return theMesh return theMesh
} }
// EnableInterface: Enables the given WireGuard interface.
func (s *MeshManagerImpl) EnableInterface(meshId string) error {
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
err = s.RouteManager.InstallRoutes()
if err != nil {
return err
}
return nil
}
// GetPublicKey: Gets the public key of the WireGuard mesh // GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) { func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) {
if s.conf.StubWg {
zeroedKey := make([]byte, wgtypes.KeyLen)
return (*wgtypes.Key)(zeroedKey), nil
}
mesh, ok := s.Meshes[meshId] mesh, ok := s.Meshes[meshId]
if !ok { if !ok {
@ -180,7 +233,6 @@ type AddSelfParams struct {
// WgPort is the WireGuard port to advertise // WgPort is the WireGuard port to advertise
WgPort int WgPort int
// Endpoint is the alias of the machine to send routable packets // Endpoint is the alias of the machine to send routable packets
// to
Endpoint string Endpoint string
} }
@ -192,35 +244,43 @@ func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
return fmt.Errorf("addself: mesh %s does not exist", params.MeshId) return fmt.Errorf("addself: mesh %s does not exist", params.MeshId)
} }
pubKey, err := s.GetPublicKey(params.MeshId) if params.WgPort == 0 && !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil { if err != nil {
return err return err
}
params.WgPort = device.ListenPort
} }
nodeIP, err := s.ipAllocator.GetIP(*pubKey, params.MeshId) pubKey := s.HostParameters.PrivateKey.PublicKey()
nodeIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId)
if err != nil { if err != nil {
return err return err
} }
node := s.nodeFactory.Build(&MeshNodeFactoryParams{ node := s.nodeFactory.Build(&MeshNodeFactoryParams{
PublicKey: pubKey, PublicKey: &pubKey,
NodeIP: nodeIP, NodeIP: nodeIP,
WgPort: params.WgPort, WgPort: params.WgPort,
Endpoint: params.Endpoint, Endpoint: params.Endpoint,
}) })
device, err := mesh.GetDevice() if !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil { if err != nil {
return fmt.Errorf("failed to get device %w", err) return fmt.Errorf("failed to get device %w", err)
} }
err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP)) err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP))
if err != nil { if err != nil {
return fmt.Errorf("addSelf: failed to add address to dev %w", err) return fmt.Errorf("addSelf: failed to add address to dev %w", err)
}
} }
s.Meshes[params.MeshId].AddNode(node) s.Meshes[params.MeshId].AddNode(node)
@ -235,19 +295,23 @@ func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
return fmt.Errorf("mesh %s does not exist", meshId) return fmt.Errorf("mesh %s does not exist", meshId)
} }
err := s.RouteManager.RemoveRoutes(meshId) var err error
if err != nil { if !s.conf.StubWg {
return err device, err := mesh.GetDevice()
if err != nil {
return err
}
err = s.interfaceManipulator.RemoveInterface(device.Name)
if err != nil {
return err
}
} }
device, err := mesh.GetDevice() err = s.RouteManager.RemoveRoutes(meshId)
if err != nil {
return err
}
err = s.interfaceManipulator.RemoveInterface(device.Name)
delete(s.Meshes, meshId) delete(s.Meshes, meshId)
return err return err
} }
@ -259,15 +323,10 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
return nil, fmt.Errorf("mesh %s does not exist", meshId) return nil, fmt.Errorf("mesh %s does not exist", meshId)
} }
snapshot, err := meshInstance.GetMesh() logging.Log.WriteInfof(s.HostParameters.GetPublicKey())
node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
if err != nil { if err != nil {
return nil, err
}
node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if !ok {
return nil, errors.New("the node doesn't exist in the mesh") return nil, errors.New("the node doesn't exist in the mesh")
} }
@ -275,40 +334,52 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
} }
func (s *MeshManagerImpl) ApplyConfig() error { func (s *MeshManagerImpl) ApplyConfig() error {
if s.conf.StubWg {
return nil
}
err := s.configApplyer.ApplyConfig() err := s.configApplyer.ApplyConfig()
if err != nil { if err != nil {
return err return err
} }
return s.RouteManager.InstallRoutes() return nil
} }
func (s *MeshManagerImpl) SetDescription(description string) error { func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description) if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
if err != nil { if err != nil {
return err return err
}
} }
} }
return nil return nil
} }
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
if err != nil {
return err
}
}
}
return nil
}
// UpdateTimeStamp updates the timestamp of this node in all meshes // UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error { func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
snapshot, err := mesh.GetMesh() if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey())
if err != nil {
return err
}
_, exists := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if exists {
err = mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
if err != nil { if err != nil {
return err return err
@ -327,7 +398,12 @@ func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
return s.Meshes return s.Meshes
} }
// Close the mesh manager
func (s *MeshManagerImpl) Close() error { func (s *MeshManagerImpl) Close() error {
if s.conf.StubWg {
return nil
}
for _, mesh := range s.Meshes { for _, mesh := range s.Meshes {
dev, err := mesh.GetDevice() dev, err := mesh.GetDevice()
@ -359,18 +435,12 @@ type NewMeshManagerParams struct {
} }
// Creates a new instance of a mesh manager with the given parameters // Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl { func NewMeshManager(params *NewMeshManagerParams) MeshManager {
hostParams := HostParameters{} privateKey, _ := wgtypes.GeneratePrivateKey()
hostParams := HostParameters{
switch params.Conf.Endpoint { PrivateKey: &privateKey,
case "":
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", lib.GetOutboundIP().String(), params.Conf.GrpcPort)
default:
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", params.Conf.Endpoint, params.Conf.GrpcPort)
} }
logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint)
m := &MeshManagerImpl{ m := &MeshManagerImpl{
Meshes: make(map[string]MeshProvider), Meshes: make(map[string]MeshProvider),
HostParameters: &hostParams, HostParameters: &hostParams,
@ -390,5 +460,11 @@ func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl {
m.idGenerator = params.IdGenerator m.idGenerator = params.IdGenerator
m.ipAllocator = params.IPAllocator m.ipAllocator = params.IPAllocator
m.interfaceManipulator = params.InterfaceManipulator m.interfaceManipulator = params.InterfaceManipulator
m.Monitor = NewMeshMonitor(m)
aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
return m return m
} }

View File

@ -22,7 +22,7 @@ func getMeshConfiguration() *conf.WgMeshConfiguration {
} }
} }
func getMeshManager() *MeshManagerImpl { func getMeshManager() MeshManager {
manager := NewMeshManager(&NewMeshManagerParams{ manager := NewMeshManager(&NewMeshManagerParams{
Conf: *getMeshConfiguration(), Conf: *getMeshConfiguration(),
Client: nil, Client: nil,
@ -51,7 +51,7 @@ func TestCreateMeshCreatesANewMeshProvider(t *testing.T) {
t.Fatal(`meshId should not be empty`) t.Fatal(`meshId should not be empty`)
} }
_, exists := manager.Meshes[meshId] _, exists := manager.GetMeshes()[meshId]
if !exists { if !exists {
t.Fatal(`mesh was not created when it should be`) t.Fatal(`mesh was not created when it should be`)
@ -64,7 +64,6 @@ func TestAddMeshAddsAMesh(t *testing.T) {
manager.AddMesh(&AddMeshParams{ manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -83,7 +82,6 @@ func TestAddMeshMeshAlreadyExistsReplacesIt(t *testing.T) {
for i := 0; i < 2; i++ { for i := 0; i < 2; i++ {
err := manager.AddMesh(&AddMeshParams{ err := manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -106,7 +104,6 @@ func TestAddSelfAddsSelfToTheMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{ err := manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -175,7 +172,6 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{ err := manager.AddMesh(&AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: "wg0",
WgPort: 6000, WgPort: 6000,
MeshBytes: make([]byte, 0), MeshBytes: make([]byte, 0),
}) })
@ -190,7 +186,7 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
t.Fatalf("%s", err.Error()) t.Fatalf("%s", err.Error())
} }
_, exists := manager.Meshes[meshId] _, exists := manager.GetMeshes()[meshId]
if exists { if exists {
t.Fatalf(`expected mesh to have been deleted`) t.Fatalf(`expected mesh to have been deleted`)
@ -201,8 +197,8 @@ func TestSetDescription(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
description := "wooooo" description := "wooooo"
meshId1, _ := manager.CreateMesh("wg0", 5000) meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh("wg0", 5001) meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId1, MeshId: meshId1,
@ -225,8 +221,8 @@ func TestSetDescription(t *testing.T) {
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) { func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager() manager := getMeshManager()
meshId1, _ := manager.CreateMesh("wg0", 5000) meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh("wg0", 5001) meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{ manager.AddSelf(&AddSelfParams{
MeshId: meshId1, MeshId: meshId1,

81
pkg/mesh/monitor.go Normal file
View File

@ -0,0 +1,81 @@
package mesh
type OnChange = func([]MeshNode)
type MeshMonitor interface {
AddUpdateCallback(cb OnChange)
AddRemoveCallback(cb OnChange)
Trigger() error
}
type MeshMonitorImpl struct {
updateCbs []OnChange
removeCbs []OnChange
nodes map[string]MeshNode
manager MeshManager
}
// Trigger causes the mesh monitor to trigger all of
// the callbacks.
func (m *MeshMonitorImpl) Trigger() error {
changedNodes := make([]MeshNode, 0)
removedNodes := make([]MeshNode, 0)
nodes := make(map[string]MeshNode)
for _, mesh := range m.manager.GetMeshes() {
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
for _, node := range snapshot.GetNodes() {
previous, exists := m.nodes[node.GetWgHost().String()]
if !exists || !NodeEquals(previous, node) {
changedNodes = append(changedNodes, node)
}
nodes[node.GetWgHost().String()] = node
}
}
for _, previous := range m.nodes {
_, ok := nodes[previous.GetWgHost().String()]
if !ok {
removedNodes = append(removedNodes, previous)
}
}
if len(removedNodes) > 0 {
for _, cb := range m.removeCbs {
cb(removedNodes)
}
}
if len(changedNodes) > 0 {
for _, cb := range m.updateCbs {
cb(changedNodes)
}
}
return nil
}
func (m *MeshMonitorImpl) AddUpdateCallback(cb OnChange) {
m.updateCbs = append(m.updateCbs, cb)
}
func (m *MeshMonitorImpl) AddRemoveCallback(cb OnChange) {
m.removeCbs = append(m.removeCbs, cb)
}
func NewMeshMonitor(manager MeshManager) MeshMonitor {
return &MeshMonitorImpl{
updateCbs: make([]OnChange, 0),
nodes: make(map[string]MeshNode),
manager: manager,
}
}

View File

@ -7,19 +7,16 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ip" "github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log" logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
type RouteManager interface { type RouteManager interface {
UpdateRoutes() error UpdateRoutes() error
InstallRoutes() error
RemoveRoutes(meshId string) error RemoveRoutes(meshId string) error
} }
type RouteManagerImpl struct { type RouteManagerImpl struct {
meshManager MeshManager meshManager MeshManager
routeInstaller route.RouteInstaller
} }
func (r *RouteManagerImpl) UpdateRoutes() error { func (r *RouteManagerImpl) UpdateRoutes() error {
@ -27,6 +24,24 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
ulaBuilder := new(ip.ULABuilder) ulaBuilder := new(ip.ULABuilder)
for _, mesh1 := range meshes { for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
if err != nil {
return err
}
pubKey, err := self.GetPublicKey()
if err != nil {
return err
}
routes, err := mesh1.GetRoutes(pubKey.String())
if err != nil {
return err
}
for _, mesh2 := range meshes { for _, mesh2 := range meshes {
if mesh1 == mesh2 { if mesh1 == mesh2 {
continue continue
@ -39,13 +54,10 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return err return err
} }
self, err := r.meshManager.GetSelf(mesh1.GetMeshId()) err = mesh2.AddRoutes(NodeID(self), append(lib.MapValues(routes), &RouteStub{
Destination: ipNet,
if err != nil { HopCount: 0,
return err })...)
}
err = mesh1.AddRoutes(self.GetHostEndpoint(), ipNet.String())
if err != nil { if err != nil {
return err return err
@ -74,7 +86,7 @@ func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
return err return err
} }
mesh1.RemoveRoutes(self.GetHostEndpoint(), ipNet.String()) mesh1.RemoveRoutes(NodeID(self), ipNet.String())
} }
return nil return nil
} }
@ -128,7 +140,11 @@ func (m *RouteManagerImpl) installRoute(ifName string, meshid string, node MeshN
return err return err
} }
routes := lib.Map(append(node.GetRoutes(), ipNet.String()), routeMapFunc) theRoutes := lib.Map(node.GetRoutes(), func(r Route) string {
return r.GetDestination().String()
})
routes := lib.Map(append(theRoutes, ipNet.String()), routeMapFunc)
return m.addRoute(ifName, ipNet.String(), routes...) return m.addRoute(ifName, ipNet.String(), routes...)
} }
@ -152,7 +168,7 @@ func (m *RouteManagerImpl) installRoutes(meshProvider MeshProvider) error {
} }
for _, node := range mesh.GetNodes() { for _, node := range mesh.GetNodes() {
if self.GetHostEndpoint() == node.GetHostEndpoint() { if NodeEquals(self, node) {
continue continue
} }
@ -180,5 +196,5 @@ func (r *RouteManagerImpl) InstallRoutes() error {
} }
func NewRouteManager(m MeshManager) RouteManager { func NewRouteManager(m MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()} return &RouteManagerImpl{meshManager: m}
} }

View File

@ -16,11 +16,26 @@ type MeshNodeStub struct {
wgEndpoint string wgEndpoint string
wgHost *net.IPNet wgHost *net.IPNet
timeStamp int64 timeStamp int64
routes []string routes []Route
identifier string identifier string
description string description string
} }
// GetType implements MeshNode.
func (*MeshNodeStub) GetType() conf.NodeType {
return conf.PEER_ROLE
}
// GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string)
}
// GetAlias implements MeshNode.
func (*MeshNodeStub) GetAlias() string {
return ""
}
func (m *MeshNodeStub) GetHostEndpoint() string { func (m *MeshNodeStub) GetHostEndpoint() string {
return m.hostEndpoint return m.hostEndpoint
} }
@ -41,7 +56,7 @@ func (m *MeshNodeStub) GetTimeStamp() int64 {
return m.timeStamp return m.timeStamp
} }
func (m *MeshNodeStub) GetRoutes() []string { func (m *MeshNodeStub) GetRoutes() []Route {
return m.routes return m.routes
} }
@ -66,9 +81,43 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub snapshot *MeshSnapshotStub
} }
func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) {
return nil, nil
}
// GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetPeers() []string {
return make([]string, 0)
}
// GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
return nil, nil
}
// NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool {
return false
}
// AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
return nil
}
// RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
return nil
}
// SetAlias implements MeshProvider.
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
panic("unimplemented")
}
// RemoveRoutes implements MeshProvider. // RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error { func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
panic("unimplemented") return nil
} }
// Prune implements MeshProvider. // Prune implements MeshProvider.
@ -114,7 +163,7 @@ func (s *MeshProviderStub) HasChanges() bool {
return false return false
} }
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...string) error { func (s *MeshProviderStub) AddRoutes(nodeId string, route ...Route) error {
return nil return nil
} }
@ -148,7 +197,7 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort), wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort),
wgHost: wgHost, wgHost: wgHost,
timeStamp: time.Now().Unix(), timeStamp: time.Now().Unix(),
routes: make([]string, 0), routes: make([]Route, 0),
identifier: "abc", identifier: "abc",
description: "A Mesh Node Stub", description: "A Mesh Node Stub",
} }
@ -171,6 +220,36 @@ type MeshManagerStub struct {
meshes map[string]MeshProvider meshes map[string]MeshProvider
} }
// GetRouteManager implements MeshManager.
func (*MeshManagerStub) GetRouteManager() RouteManager {
panic("unimplemented")
}
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented")
}
// RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error {
panic("unimplemented")
}
// SetService implements MeshManager.
func (*MeshManagerStub) SetService(service string, value string) error {
panic("unimplemented")
}
// GetMonitor implements MeshManager.
func (*MeshManagerStub) GetMonitor() MeshMonitor {
panic("unimplemented")
}
// SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(alias string) error {
panic("unimplemented")
}
// Close implements MeshManager. // Close implements MeshManager.
func (*MeshManagerStub) Close() error { func (*MeshManagerStub) Close() error {
panic("unimplemented") panic("unimplemented")
@ -185,7 +264,7 @@ func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)} return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
} }
func (m *MeshManagerStub) CreateMesh(devName string, port int) (string, error) { func (m *MeshManagerStub) CreateMesh(port int) (string, error) {
return "tim123", nil return "tim123", nil
} }
@ -208,10 +287,6 @@ func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider {
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}} snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}}
} }
func (m *MeshManagerStub) EnableInterface(meshId string) error {
return nil
}
func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) { func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) {
key, _ := wgtypes.GenerateKey() key, _ := wgtypes.GenerateKey()
return &key, nil return &key, nil

View File

@ -10,6 +10,26 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
type Route interface {
// GetDestination: returns the destination of the route
GetDestination() *net.IPNet
// GetHopCount: get the total hopcount of the prefix
GetHopCount() int
}
type RouteStub struct {
Destination *net.IPNet
HopCount int
}
func (r *RouteStub) GetDestination() *net.IPNet {
return r.Destination
}
func (r *RouteStub) GetHopCount() int {
return r.HopCount
}
// MeshNode represents an implementation of a node in a mesh // MeshNode represents an implementation of a node in a mesh
type MeshNode interface { type MeshNode interface {
// GetHostEndpoint: gets the gRPC endpoint of the node // GetHostEndpoint: gets the gRPC endpoint of the node
@ -23,11 +43,35 @@ type MeshNode interface {
// GetTimestamp: get the UNIX time stamp of the ndoe // GetTimestamp: get the UNIX time stamp of the ndoe
GetTimeStamp() int64 GetTimeStamp() int64
// GetRoutes: returns the routes that the nodes provides // GetRoutes: returns the routes that the nodes provides
GetRoutes() []string GetRoutes() []Route
// GetIdentifier: returns the identifier of the node // GetIdentifier: returns the identifier of the node
GetIdentifier() string GetIdentifier() string
// GetDescription: returns the description for this node // GetDescription: returns the description for this node
GetDescription() string GetDescription() string
// GetAlias: associates the node with an alias. Potentially used
// for DNS and so forth.
GetAlias() string
// GetServices: returns a list of services offered by the node
GetServices() map[string]string
GetType() conf.NodeType
}
// NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool {
key1, _ := node1.GetPublicKey()
key2, _ := node2.GetPublicKey()
return key1.String() == key2.String()
}
func RouteEquals(route1, route2 Route) bool {
return route1.GetDestination().String() == route2.GetDestination().String() &&
route1.GetHopCount() == route2.GetHopCount()
}
func NodeID(node MeshNode) string {
key, _ := node.GetPublicKey()
return key.String()
} }
type MeshSnapshot interface { type MeshSnapshot interface {
@ -46,7 +90,7 @@ type MeshSyncer interface {
type MeshProvider interface { type MeshProvider interface {
// AddNode() adds a node to the mesh // AddNode() adds a node to the mesh
AddNode(node MeshNode) AddNode(node MeshNode)
// GetMesh() returns a snapshot of the mesh provided by the mesh provider // GetMesh() returns a snapshot of the mesh provided by the mesh provider.
GetMesh() (MeshSnapshot, error) GetMesh() (MeshSnapshot, error)
// GetMeshId() returns the ID of the mesh network // GetMeshId() returns the ID of the mesh network
GetMeshId() string GetMeshId() string
@ -63,21 +107,40 @@ type MeshProvider interface {
// UpdateTimeStamp: update the timestamp of the given node // UpdateTimeStamp: update the timestamp of the given node
UpdateTimeStamp(nodeId string) error UpdateTimeStamp(nodeId string) error
// AddRoutes: adds routes to the given node // AddRoutes: adds routes to the given node
AddRoutes(nodeId string, route ...string) error AddRoutes(nodeId string, route ...Route) error
// DeleteRoutes: deletes the routes from the node // DeleteRoutes: deletes the routes from the node
RemoveRoutes(nodeId string, route ...string) error RemoveRoutes(nodeId string, route ...string) error
// GetSyncer: returns the automerge syncer for sync // GetSyncer: returns the automerge syncer for sync
GetSyncer() MeshSyncer GetSyncer() MeshSyncer
// GetNode get a particular not within the mesh
GetNode(string) (MeshNode, error)
// NodeExists: returns true if a particular node exists false otherwise
NodeExists(string) bool
// SetDescription: sets the description of this automerge data type // SetDescription: sets the description of this automerge data type
SetDescription(nodeId string, description string) error SetDescription(nodeId string, description string) error
// SetAlias: set the alias of the nodeId
SetAlias(nodeId string, alias string) error
// AddService: adds the service to the given node
AddService(nodeId, key, value string) error
// RemoveService: removes the service form the node. throws an error if the service does not exist
RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their timestamp in // Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds // pruneAmount seconds
Prune(pruneAmount int) error Prune(pruneAmount int) error
// GetPeers: get a list of contactable peers
GetPeers() []string
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
GetRoutes(targetNode string) (map[string]Route, error)
} }
// HostParameters contains the IDs of a node // HostParameters contains the IDs of a node
type HostParameters struct { type HostParameters struct {
HostEndpoint string PrivateKey *wgtypes.Key
}
// GetPublicKey: gets the public key of the node
func (h *HostParameters) GetPublicKey() string {
return h.PrivateKey.PublicKey().String()
} }
// MeshProviderFactoryParams parameters required to build a mesh provider // MeshProviderFactoryParams parameters required to build a mesh provider

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/jmespath/go-jmespath" "github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
) )
@ -23,14 +24,22 @@ type QueryError struct {
msg string msg string
} }
type QueryRoute struct {
Destination string `json:"destination"`
HopCount int `json:"hopCount"`
}
type QueryNode struct { type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"` HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"` WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"` WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestmap"` Timestamp int64 `json:"timestamp"`
Description string `json:"description"` Description string `json:"description"`
Routes []string `json:"routes"` Routes []QueryRoute `json:"routes"`
Alias string `json:"alias"`
Services map[string]string `json:"services"`
Type conf.NodeType `json:"type"`
} }
func (m *QueryError) Error() string { func (m *QueryError) Error() string {
@ -51,7 +60,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return nil, err return nil, err
} }
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), meshNodeToQueryNode) nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), MeshNodeToQueryNode)
result, err := jmespath.Search(queryParams, nodes) result, err := jmespath.Search(queryParams, nodes)
@ -63,7 +72,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return bytes, err return bytes, err
} }
func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode { func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode := new(QueryNode) queryNode := new(QueryNode)
queryNode.HostEndpoint = node.GetHostEndpoint() queryNode.HostEndpoint = node.GetHostEndpoint()
pubKey, _ := node.GetPublicKey() pubKey, _ := node.GetPublicKey()
@ -74,8 +83,17 @@ func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.WgHost = node.GetWgHost().String() queryNode.WgHost = node.GetWgHost().String()
queryNode.Timestamp = node.GetTimeStamp() queryNode.Timestamp = node.GetTimeStamp()
queryNode.Routes = node.GetRoutes() queryNode.Routes = lib.Map(node.GetRoutes(), func(r mesh.Route) QueryRoute {
return QueryRoute{
Destination: r.GetDestination().String(),
HopCount: r.GetHopCount(),
}
})
queryNode.Description = node.GetDescription() queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices()
queryNode.Type = node.GetType()
return queryNode return queryNode
} }

View File

@ -2,6 +2,7 @@ package robin
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"strconv" "strconv"
@ -9,7 +10,9 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc" "github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh" "github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc" "github.com/tim-beatham/wgmesh/pkg/rpc"
) )
@ -18,7 +21,7 @@ type IpcHandler struct {
} }
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.GetMeshManager().CreateMesh(args.IfName, args.WgPort) meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort)
if err != nil { if err != nil {
return err return err
@ -70,7 +73,9 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err return err
} }
ctx, cancel := context.WithTimeout(context.Background(), time.Second) configuration := n.Server.GetConfiguration()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(configuration.Timeout))
defer cancel() defer cancel()
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId}) meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
@ -81,7 +86,6 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{ err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId, MeshId: args.MeshId,
DevName: args.IfName,
WgPort: args.Port, WgPort: args.Port,
MeshBytes: meshReply.Mesh, MeshBytes: meshReply.Mesh,
}) })
@ -116,14 +120,19 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
} }
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error { func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.GetMeshManager().GetMesh(meshId) theMesh := n.Server.GetMeshManager().GetMesh(meshId)
meshSnapshot, err := mesh.GetMesh()
if theMesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
meshSnapshot, err := theMesh.GetMesh()
if err != nil { if err != nil {
return err return err
} }
if mesh == nil { if theMesh == nil {
return errors.New("mesh does not exist") return errors.New("mesh does not exist")
} }
@ -143,7 +152,12 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
PublicKey: pubKey.String(), PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(), WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(), Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(), Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) string {
return r.GetDestination().String()
}),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
} }
nodes[i] = node nodes[i] = node
@ -154,18 +168,6 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
return nil return nil
} }
func (n *IpcHandler) EnableInterface(meshId string, reply *string) error {
err := n.Server.GetMeshManager().EnableInterface(meshId)
if err != nil {
*reply = err.Error()
return err
}
*reply = "up"
return nil
}
func (n *IpcHandler) GetDOT(meshId string, reply *string) error { func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager()) g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())
@ -201,6 +203,60 @@ func (n *IpcHandler) PutDescription(description string, reply *string) error {
return nil return nil
} }
func (n *IpcHandler) PutAlias(alias string, reply *string) error {
err := n.Server.GetMeshManager().SetAlias(alias)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set alias to %s", alias)
return nil
}
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.Service, service.Value)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) DeleteService(service string, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) GetNode(args ipc.GetNodeArgs, reply *string) error {
node := n.Server.GetMeshManager().GetNode(args.MeshId, args.NodeId)
if node == nil {
*reply = "nil"
return nil
}
queryNode := query.MeshNodeToQueryNode(node)
bytes, err := json.Marshal(queryNode)
if err != nil {
*reply = err.Error()
return nil
}
*reply = string(bytes)
return nil
}
type RobinIpcParams struct { type RobinIpcParams struct {
CtrlServer ctrlserver.CtrlServer CtrlServer ctrlserver.CtrlServer
} }

View File

@ -28,7 +28,3 @@ func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.
return &reply, nil return &reply, nil
} }
func (m *WgRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
return &rpc.JoinMeshReply{Success: true}, nil
}

View File

@ -1,22 +1,32 @@
package route package route
import ( import (
"net" "github.com/tim-beatham/wgmesh/pkg/lib"
"os/exec" "golang.org/x/sys/unix"
logging "github.com/tim-beatham/wgmesh/pkg/log"
) )
type RouteInstaller interface { type RouteInstaller interface {
InstallRoutes(devName string, routes ...*net.IPNet) error InstallRoutes(devName string, routes ...lib.Route) error
} }
type RouteInstallerImpl struct{} type RouteInstallerImpl struct{}
// InstallRoutes: installs a route into the routing table // InstallRoutes: installs a route into the routing table
func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet) error { func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
}
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes { for _, route := range routes {
err := r.installRoute(devName, route) err := rtnl.AddRoute(devName, route)
if err != nil { if err != nil {
return err return err
@ -26,22 +36,6 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet)
return nil return nil
} }
// installRoute: installs a route into the linux table
func (r *RouteInstallerImpl) installRoute(devName string, route *net.IPNet) error {
// TODO: Find a library that automates this
cmd := exec.Command("/usr/bin/ip", "-6", "route", "add", route.String(), "dev", devName)
logging.Log.WriteInfof("%s %s", route.String(), devName)
if msg, err := cmd.CombinedOutput(); err != nil {
logging.Log.WriteErrorf(err.Error())
logging.Log.WriteErrorf(string(msg))
return err
}
return nil
}
func NewRouteInstaller() RouteInstaller { func NewRouteInstaller() RouteInstaller {
return &RouteInstallerImpl{} return &RouteInstallerImpl{}
} }

View File

@ -20,77 +20,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
) )
type MeshNode struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
PublicKey string `protobuf:"bytes,1,opt,name=publicKey,proto3" json:"publicKey,omitempty"`
WgEndpoint string `protobuf:"bytes,2,opt,name=wgEndpoint,proto3" json:"wgEndpoint,omitempty"`
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
WgHost string `protobuf:"bytes,4,opt,name=wgHost,proto3" json:"wgHost,omitempty"`
}
func (x *MeshNode) Reset() {
*x = MeshNode{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *MeshNode) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*MeshNode) ProtoMessage() {}
func (x *MeshNode) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use MeshNode.ProtoReflect.Descriptor instead.
func (*MeshNode) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *MeshNode) GetPublicKey() string {
if x != nil {
return x.PublicKey
}
return ""
}
func (x *MeshNode) GetWgEndpoint() string {
if x != nil {
return x.WgEndpoint
}
return ""
}
func (x *MeshNode) GetEndpoint() string {
if x != nil {
return x.Endpoint
}
return ""
}
func (x *MeshNode) GetWgHost() string {
if x != nil {
return x.WgHost
}
return ""
}
type GetMeshRequest struct { type GetMeshRequest struct {
state protoimpl.MessageState state protoimpl.MessageState
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
@ -102,7 +31,7 @@ type GetMeshRequest struct {
func (x *GetMeshRequest) Reset() { func (x *GetMeshRequest) Reset() {
*x = GetMeshRequest{} *x = GetMeshRequest{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@ -115,7 +44,7 @@ func (x *GetMeshRequest) String() string {
func (*GetMeshRequest) ProtoMessage() {} func (*GetMeshRequest) ProtoMessage() {}
func (x *GetMeshRequest) ProtoReflect() protoreflect.Message { func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@ -128,7 +57,7 @@ func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead. // Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead.
func (*GetMeshRequest) Descriptor() ([]byte, []int) { func (*GetMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1} return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
} }
func (x *GetMeshRequest) GetMeshId() string { func (x *GetMeshRequest) GetMeshId() string {
@ -149,7 +78,7 @@ type GetMeshReply struct {
func (x *GetMeshReply) Reset() { func (x *GetMeshReply) Reset() {
*x = GetMeshReply{} *x = GetMeshReply{}
if protoimpl.UnsafeEnabled { if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@ -162,7 +91,7 @@ func (x *GetMeshReply) String() string {
func (*GetMeshReply) ProtoMessage() {} func (*GetMeshReply) ProtoMessage() {}
func (x *GetMeshReply) ProtoReflect() protoreflect.Message { func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2] mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil { if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@ -175,7 +104,7 @@ func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead. // Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead.
func (*GetMeshReply) Descriptor() ([]byte, []int) { func (*GetMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{2} return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1}
} }
func (x *GetMeshReply) GetMesh() []byte { func (x *GetMeshReply) GetMesh() []byte {
@ -185,145 +114,24 @@ func (x *GetMeshReply) GetMesh() []byte {
return nil return nil
} }
type JoinMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Changes []byte `protobuf:"bytes,1,opt,name=changes,proto3" json:"changes,omitempty"`
MeshId string `protobuf:"bytes,2,opt,name=meshId,proto3" json:"meshId,omitempty"`
}
func (x *JoinMeshRequest) Reset() {
*x = JoinMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshRequest) ProtoMessage() {}
func (x *JoinMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshRequest.ProtoReflect.Descriptor instead.
func (*JoinMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{3}
}
func (x *JoinMeshRequest) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
func (x *JoinMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
type JoinMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
}
func (x *JoinMeshReply) Reset() {
*x = JoinMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshReply) ProtoMessage() {}
func (x *JoinMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshReply.ProtoReflect.Descriptor instead.
func (*JoinMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{4}
}
func (x *JoinMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
var File_pkg_grpc_ctrlserver_ctrlserver_proto protoreflect.FileDescriptor var File_pkg_grpc_ctrlserver_ctrlserver_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc = []byte{ var file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc = []byte{
0x0a, 0x24, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x0a, 0x24, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x22, 0x7c, 0x0a, 0x08, 0x4d, 0x65, 0x73, 0x68, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65,
0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01,
0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x77, 0x67, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65,
0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65,
0x77, 0x67, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e, 0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x32, 0x4f,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e, 0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, 0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70,
0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, 0x22, 0x28, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42,
0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x6f, 0x33,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x22, 0x43, 0x0a, 0x0f,
0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12,
0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73,
0x68, 0x49, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49,
0x64, 0x22, 0x29, 0x0a, 0x0d, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20,
0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x32, 0x91, 0x01, 0x0a,
0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12,
0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e,
0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x40,
0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x19, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00,
0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
} }
var ( var (
@ -338,21 +146,16 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP() []byte {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescData return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescData
} }
var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_ctrlserver_proto_goTypes = []interface{}{ var file_pkg_grpc_ctrlserver_ctrlserver_proto_goTypes = []interface{}{
(*MeshNode)(nil), // 0: rpctypes.MeshNode (*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest
(*GetMeshRequest)(nil), // 1: rpctypes.GetMeshRequest (*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply
(*GetMeshReply)(nil), // 2: rpctypes.GetMeshReply
(*JoinMeshRequest)(nil), // 3: rpctypes.JoinMeshRequest
(*JoinMeshReply)(nil), // 4: rpctypes.JoinMeshReply
} }
var file_pkg_grpc_ctrlserver_ctrlserver_proto_depIdxs = []int32{ var file_pkg_grpc_ctrlserver_ctrlserver_proto_depIdxs = []int32{
1, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest 0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
3, // 1: rpctypes.MeshCtrlServer.JoinMesh:input_type -> rpctypes.JoinMeshRequest 1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
2, // 2: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply 1, // [1:2] is the sub-list for method output_type
4, // 3: rpctypes.MeshCtrlServer.JoinMesh:output_type -> rpctypes.JoinMeshReply 0, // [0:1] is the sub-list for method input_type
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name 0, // [0:0] is the sub-list for field type_name
@ -365,18 +168,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
} }
if !protoimpl.UnsafeEnabled { if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*MeshNode); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshRequest); i { switch v := v.(*GetMeshRequest); i {
case 0: case 0:
return &v.state return &v.state
@ -388,7 +179,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil return nil
} }
} }
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshReply); i { switch v := v.(*GetMeshReply); i {
case 0: case 0:
return &v.state return &v.state
@ -400,30 +191,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil return nil
} }
} }
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
} }
type x struct{} type x struct{}
out := protoimpl.TypeBuilder{ out := protoimpl.TypeBuilder{
@ -431,7 +198,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc, RawDescriptor: file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc,
NumEnums: 0, NumEnums: 0,
NumMessages: 5, NumMessages: 2,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@ -23,7 +23,6 @@ const _ = grpc.SupportPackageIsVersion7
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type MeshCtrlServerClient interface { type MeshCtrlServerClient interface {
GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error) GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error)
JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error)
} }
type meshCtrlServerClient struct { type meshCtrlServerClient struct {
@ -43,21 +42,11 @@ func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest,
return out, nil return out, nil
} }
func (c *meshCtrlServerClient) JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error) {
out := new(JoinMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/JoinMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// MeshCtrlServerServer is the server API for MeshCtrlServer service. // MeshCtrlServerServer is the server API for MeshCtrlServer service.
// All implementations must embed UnimplementedMeshCtrlServerServer // All implementations must embed UnimplementedMeshCtrlServerServer
// for forward compatibility // for forward compatibility
type MeshCtrlServerServer interface { type MeshCtrlServerServer interface {
GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error)
JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error)
mustEmbedUnimplementedMeshCtrlServerServer() mustEmbedUnimplementedMeshCtrlServerServer()
} }
@ -68,9 +57,6 @@ type UnimplementedMeshCtrlServerServer struct {
func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) { func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented")
} }
func (UnimplementedMeshCtrlServerServer) JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method JoinMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {} func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {}
// UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service. // UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service.
@ -102,24 +88,6 @@ func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec f
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _MeshCtrlServer_JoinMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JoinMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.MeshCtrlServer/JoinMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, req.(*JoinMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service. // MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@ -131,10 +99,6 @@ var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetMesh", MethodName: "GetMesh",
Handler: _MeshCtrlServer_GetMesh_Handler, Handler: _MeshCtrlServer_GetMesh_Handler,
}, },
{
MethodName: "JoinMesh",
Handler: _MeshCtrlServer_JoinMesh_Handler,
},
}, },
Streams: []grpc.StreamDesc{}, Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver/ctrlserver.proto", Metadata: "pkg/grpc/ctrlserver/ctrlserver.proto",

View File

@ -1,7 +1,6 @@
package sync package sync
import ( import (
"errors"
"math/rand" "math/rand"
"sync" "sync"
"time" "time"
@ -30,54 +29,35 @@ type SyncerImpl struct {
// Sync: Sync random nodes // Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error { func (s *SyncerImpl) Sync(meshId string) error {
logging.Log.WriteInfof("UPDATING WG CONF")
err := s.manager.ApplyConfig()
if err != nil {
logging.Log.WriteInfof("Failed to update config %w", err)
}
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 { if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId) logging.Log.WriteInfof("No changes for %s", meshId)
return nil return nil
} }
theMesh := s.manager.GetMesh(meshId) logging.Log.WriteInfof("UPDATING WG CONF")
if theMesh == nil { if s.manager.HasChanges(meshId) {
return errors.New("the provided mesh does not exist") err := s.manager.ApplyConfig()
}
if err != nil {
snapshot, err := theMesh.GetMesh() logging.Log.WriteInfof("Failed to update config %w", err)
}
if err != nil {
return err
}
nodes := snapshot.GetNodes()
if len(nodes) <= 1 {
return nil
} }
nodeNames := s.manager.GetMesh(meshId).GetPeers()
self, err := s.manager.GetSelf(meshId) self, err := s.manager.GetSelf(meshId)
if err != nil { if err != nil {
return err return err
} }
excludedNodes := map[string]struct{}{ selfPublickey, err := self.GetPublicKey()
self.GetHostEndpoint(): {},
}
meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes)
getNames := func(node mesh.MeshNode) string { if err != nil {
return node.GetHostEndpoint() return err
} }
nodeNames := lib.Map(meshNodes, getNames) neighbours := s.cluster.GetNeighbours(nodeNames, selfPublickey.String())
neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate) randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
for _, node := range randomSubset { for _, node := range randomSubset {
@ -86,9 +66,9 @@ func (s *SyncerImpl) Sync(meshId string) error {
before := time.Now() before := time.Now()
if len(meshNodes) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance { if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster") logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint()) interCluster := s.cluster.GetInterCluster(nodeNames, selfPublickey.String())
randomSubset = append(randomSubset, interCluster) randomSubset = append(randomSubset, interCluster)
} }
@ -99,7 +79,14 @@ func (s *SyncerImpl) Sync(meshId string) error {
go func(i int) error { go func(i int) error {
defer waitGroup.Done() defer waitGroup.Done()
err := s.requester.SyncMesh(meshId, randomSubset[i])
correspondingPeer := s.manager.GetNode(meshId, randomSubset[i])
if correspondingPeer == nil {
logging.Log.WriteErrorf("node %s does not exist", randomSubset[i])
}
err := s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint())
return err return err
}(index) }(index)
} }
@ -107,16 +94,20 @@ func (s *SyncerImpl) Sync(meshId string) error {
waitGroup.Wait() waitGroup.Wait()
s.syncCount++ s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Now().Sub(before)) logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount) logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount) s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
// Check if any changes have occurred and trigger callbacks
// if changes have occurred.
// return s.manager.GetMonitor().Trigger()
return nil return nil
} }
// SyncMeshes: Sync all meshes // SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error { func (s *SyncerImpl) SyncMeshes() error {
for meshId, _ := range s.manager.GetMeshes() { for meshId := range s.manager.GetMeshes() {
err := s.Sync(meshId) err := s.Sync(meshId)
if err != nil { if err != nil {

View File

@ -50,7 +50,6 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{ err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: ifName,
WgPort: port, WgPort: port,
MeshBytes: reply.Mesh, MeshBytes: reply.Mesh,
}) })

View File

@ -5,20 +5,6 @@ import (
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
) )
// SyncScheduler: Loops through all nodes in the mesh and runs a schedule to
// sync each event
type SyncScheduler interface {
Run() error
Stop() error
}
// SyncSchedulerImpl scheduler for sync scheduling
type SyncSchedulerImpl struct {
quit chan struct{}
server *ctrlserver.MeshCtrlServer
syncer Syncer
}
// Run implements SyncScheduler. // Run implements SyncScheduler.
func syncFunction(syncer Syncer) lib.TimerFunc { func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error { return func() error {

View File

@ -64,11 +64,11 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
syncer = mesh.GetSyncer() syncer = mesh.GetSyncer()
} else if meshId != in.MeshId { } else if meshId != in.MeshId {
return errors.New("Differing MeshIDs") return errors.New("differing meshids")
} }
if syncer == nil { if syncer == nil {
return errors.New("Syncer should not be nil") return errors.New("syncer should not be nil")
} }
msg, moreMessages := syncer.GenerateMessage() msg, moreMessages := syncer.GenerateMessage()

View File

@ -1,4 +1,4 @@
package timestamp package timer
import ( import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver" "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
@ -12,3 +12,11 @@ func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime) return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
} }
func NewRouteScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
return ctrlServer.MeshManager.GetRouteManager().UpdateRoutes()
}
return *lib.NewTimer(timerFunc, 10)
}

View File

@ -2,8 +2,8 @@ package wg
type WgInterfaceManipulatorStub struct{} type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(params *CreateInterfaceParams) error { func (i *WgInterfaceManipulatorStub) CreateInterface(port int) (string, error) {
return nil return "", nil
} }
func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error { func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error {

View File

@ -1,5 +1,7 @@
package wg package wg
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
type WgError struct { type WgError struct {
msg string msg string
} }
@ -8,14 +10,9 @@ func (m *WgError) Error() string {
return m.msg return m.msg
} }
type CreateInterfaceParams struct {
IfName string
Port int
}
type WgInterfaceManipulator interface { type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface // CreateInterface creates a WireGuard interface
CreateInterface(params *CreateInterfaceParams) error CreateInterface(port int, privateKey *wgtypes.Key) (string, error)
// AddAddress adds an address to the given interface name // AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface // RemoveInterface removes the specified interface

View File

@ -1,6 +1,9 @@
package wg package wg
import ( import (
"crypto"
"crypto/rand"
"encoding/base64"
"fmt" "fmt"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
@ -13,40 +16,48 @@ type WgInterfaceManipulatorImpl struct {
client *wgctrl.Client client *wgctrl.Client
} }
const hashLength = 6
// CreateInterface creates a WireGuard interface // CreateInterface creates a WireGuard interface
func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error { func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes.Key) (string, error) {
rtnl, err := lib.NewRtNetlinkConfig() rtnl, err := lib.NewRtNetlinkConfig()
if err != nil { if err != nil {
return fmt.Errorf("failed to access link: %w", err) return "", fmt.Errorf("failed to access link: %w", err)
} }
defer rtnl.Close() defer rtnl.Close()
err = rtnl.CreateLink(params.IfName) randomBuf := make([]byte, 32)
_, err = rand.Read(randomBuf)
if err != nil { if err != nil {
return fmt.Errorf("failed to create link: %w", err) return "", err
} }
privateKey, err := wgtypes.GeneratePrivateKey() md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
err = rtnl.CreateLink(md5Str)
if err != nil { if err != nil {
return fmt.Errorf("failed to create private key: %w", err) return "", fmt.Errorf("failed to create link: %w", err)
} }
var cfg wgtypes.Config = wgtypes.Config{ var cfg wgtypes.Config = wgtypes.Config{
PrivateKey: &privateKey, PrivateKey: privKey,
ListenPort: &params.Port, ListenPort: &port,
} }
err = m.client.ConfigureDevice(params.IfName, cfg) err = m.client.ConfigureDevice(md5Str, cfg)
if err != nil { if err != nil {
return fmt.Errorf("failed to configure dev: %w", err) m.RemoveInterface(md5Str)
return "", fmt.Errorf("failed to configure dev: %w", err)
} }
logging.Log.WriteInfof("ip link set up dev %s type wireguard", params.IfName) logging.Log.WriteInfof("ip link set up dev %s type wireguard", md5Str)
return nil return md5Str, nil
} }
// Add an address to the given interface // Add an address to the given interface

View File

@ -0,0 +1,92 @@
// Package to convert an IPV6 addres into 8 words
package what8words
import (
"bufio"
"bytes"
"fmt"
"net"
"os"
"strings"
)
type What8Words struct {
words []string
}
// Convert implements What8Words.
func (w *What8Words) Convert(ipStr string) (string, error) {
ip, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
return "", err
}
ip16 := ip.To16()
if ip16 == nil {
return "", fmt.Errorf("cannot convert ip to 16 representation")
}
representation := make([]string, 7)
for i := 2; i <= net.IPv6len-2; i += 2 {
word1 := w.words[ip16[i]]
word2 := w.words[ip16[i+1]]
representation[i/2-1] = fmt.Sprintf("%s-%s", word1, word2)
}
prefixSize, _ := ipNet.Mask.Size()
return strings.Join(representation[:prefixSize/16-1], "."), nil
}
// Convert implements What8Words.
func (w *What8Words) ConvertIdentifier(ipStr string) (string, error) {
ip, err := w.Convert(ipStr)
if err != nil {
return "", err
}
constituents := strings.Split(ip, ".")
return strings.Join(constituents[3:], "."), nil
}
func NewWhat8Words(pathToWords string) (*What8Words, error) {
words, err := ReadWords(pathToWords)
if err != nil {
return nil, err
}
return &What8Words{words: words}, nil
}
// ReadWords reads the what 8 words txt file
func ReadWords(wordFile string) ([]string, error) {
f, err := os.ReadFile(wordFile)
if err != nil {
return nil, err
}
words := make([]string, 257)
reader := bufio.NewScanner(bytes.NewReader(f))
counter := 0
for reader.Scan() && counter <= len(words) {
text := reader.Text()
words[counter] = text
counter++
if reader.Err() != nil {
return nil, reader.Err()
}
}
return words, nil
}

1
smegmesh-web Submodule

Submodule smegmesh-web added at c1128bcd98