Compare commits

...

43 Commits

Author SHA1 Message Date
d8e156f13f 36-add-route-path-into-route-object
Added the route path into the route object so that we can
see what meshes packets are routed across.
2023-11-27 18:55:41 +00:00
3fca49a1c9 Merge pull request #35 from tim-beatham/34-fix-routing
34 fix routing
2023-11-27 16:05:06 +00:00
a2517a1e72 34-fix-routing
- Added mesh-to-mesh routing of hop count > 1
- If there is a tie-breaker with respect to the hop-count use consistent
hashing to determine the route to take based on the public key.
2023-11-27 15:56:30 +00:00
aef8b59f22 32-fix-routing
Flooding routes into other meshes a bit like BGP.
2023-11-25 03:15:58 +00:00
4030d17b41 Fixed routing issue 2023-11-24 17:49:06 +00:00
73db65660b Merge pull request #33 from tim-beatham/32-incorporate-dns
32-incorporate-dns
2023-11-24 15:05:40 +00:00
d1a74a7b95 32-incorporate-dns
Incorporated a DNS server. A DNS server can be run to resolve host
names.
2023-11-24 15:04:07 +00:00
f28ed8260d Merge pull request #30 from tim-beatham/29-only-ping-clients-who-have-updated-their-config
29-only-ping-clients-who-have-updated-their-config
2023-11-24 12:39:14 +00:00
2c406718df 29-only-ping-clients-who-have-updated-their-config
Only consider clients who have updated their config when synchronising
with peers. Consider a dead time where we don't have a handshake and
a prune time when we remove them from the WireGuard configuration.
2023-11-24 12:37:54 +00:00
11b003b549 Merge pull request #28 from tim-beatham/27-remove-client-grpc-endpoint
27-remove-client-grpc-endpoint
2023-11-24 12:08:42 +00:00
7be11dbaa3 27-remove-client-grpc-endpoint
Removed a client's grpc endpoint value. Client's aren't publicly
available so there is no need for a client's gRPC endpoint.
Also changed a node ID's to their public key. A node id's public
address is an issue for mobility of clients as their endpoint
is subject to change
2023-11-24 12:07:03 +00:00
e7ac8c5542 Only updating WireGuard config if node exists 2023-11-22 13:08:02 +00:00
09c64c4628 Fixed container file 2023-11-22 12:45:01 +00:00
2c4f18f52b Merge pull request #26 from tim-beatham/25-modify-code-to-use-public-api
25-modify-code-to-use-public-api
2023-11-22 10:42:48 +00:00
4c54022f63 25-modify-code-to-use-public-api
Modify the code to use a public IP address by default if none is
specified
2023-11-22 10:41:54 +00:00
bf0724f6e5 Merge pull request #24 from tim-beatham/24-keepalive-holepunch
24 keepalive holepunch
2023-11-21 21:28:16 +00:00
624bd6e921 24-keepalive
Persistent keep alive working
2023-11-21 21:26:31 +00:00
7b939e0468 24-keepalive-holepunch
Added the ability to hole punch NAT
2023-11-21 20:42:43 +00:00
6e201ebaf5 24-keepalive-holepunch
Nodes acting as peers and nodes acting as clients
2023-11-21 16:42:49 +00:00
06542da03c main
Fixed problems with timestamp not updating
2023-11-21 13:31:34 +00:00
0d63cd6624 main
Adding words.txt for what words
2023-11-20 18:12:58 +00:00
f13319cfc1 Merge pull request #22 from tim-beatham/21-phonetic-words-ipv6
21 phonetic words ipv6
2023-11-20 18:08:49 +00:00
95f4495b0b 21-phonetic-words-ipv6
Simple what 8 words implementation
2023-11-20 18:07:52 +00:00
330fa74ef4 IPv6 What 8 Words
what 8 words for ipv6 started
2023-11-20 15:22:32 +00:00
3e5b57e41f Merge pull request #20 from tim-beatham/19-hash-wg-interface
Hashing the WireGuard interface
2023-11-20 13:04:19 +00:00
b179cd3cf4 Hashing the WireGuard interface
Hashing the interface and using ephmeral ports so that the admin doesn't
choose an interface and port combination. An administrator can alteranatively
decide to provide port but this isn't critical.
2023-11-20 13:03:42 +00:00
8f211aa116 Merge pull request #18 from tim-beatham/26-performance-testing
Stubbing out WireGuard components
2023-11-20 11:29:37 +00:00
388153e706 Stubbing out WireGuard components
Stubbing our WireGuard components so that I can use docker/podman
network_mode=host. This is much more efficient than the docker/podman
userspace network.
2023-11-20 11:28:12 +00:00
023565d985 Merge pull request #17 from tim-beatham/25-ability-to-aliases
25 ability to aliases
2023-11-17 22:20:57 +00:00
36c264b38e 25-ability-aliases
Fixed unit tests failing
2023-11-17 22:18:53 +00:00
68db795f47 Ability to specify aliases
Ability to specify aliases that automatically append to /etc/hosts
2023-11-17 22:13:51 +00:00
f6160fe138 Adding aliases that automatically gets added 2023-11-17 19:13:20 +00:00
2c5289afb0 Merge pull request #16 from tim-beatham/15-add-rest-api
Developed a rest API
2023-11-15 12:57:05 +00:00
7199d07a76 Added smegmesh submodule 2023-11-13 10:46:52 +00:00
5f176e731f Developed a rest API 2023-11-13 10:44:14 +00:00
44f119b45c Updating examples 2023-11-08 09:19:24 +00:00
5215d5d54d Merge pull request #14 from tim-beatham/13-netlink-api
Removed interface manipulation via os.Exec into
2023-11-07 19:53:39 +00:00
1a864b7c80 Removed interface manipulation via os.Exec into
rtnetlink calls
2023-11-07 19:48:53 +00:00
4c19ebd81f Merge pull request #12 from tim-beatham/11-health-system
11 health system
2023-11-06 13:40:04 +00:00
acbeb689b5 Prune nodes if they exceed their timeout time 2023-11-06 13:37:28 +00:00
bc6cd4fdd5 Modified syncer 2023-11-06 10:05:23 +00:00
c88012cf71 Added health system to count how many times a node
fails to conenct.
2023-11-06 09:54:06 +00:00
4dc85f3861 Merge pull request #10 from tim-beatham/9-add-ci-support
9 add ci support
2023-11-05 18:07:52 +00:00
64 changed files with 3362 additions and 798 deletions

3
.gitmodules vendored Normal file
View File

@ -0,0 +1,3 @@
[submodule "smegmesh-web"]
path = smegmesh-web
url = git@github.com:tim-beatham/smegmesh-web.git

12
Containerfile Normal file
View File

@ -0,0 +1,12 @@
FROM docker.io/library/golang:bookworm
COPY ./ /wgmesh
RUN apt-get update && apt-get install -y \
wireguard \
wireguard-tools \
iproute2 \
iputils-ping \
tmux \
vim
WORKDIR /wgmesh
RUN go mod tidy
RUN go build -o /usr/local/bin ./...

1
Dockerfile Symbolic link
View File

@ -0,0 +1 @@
Containerfile

19
cmd/api/main.go Normal file
View File

@ -0,0 +1,19 @@
package main
import (
"log"
"github.com/tim-beatham/wgmesh/pkg/api"
)
func main() {
apiServer, err := api.NewSmegServer(api.ApiServerConf{
WordsFile: "./cmd/api/words.txt",
})
if err != nil {
log.Fatal(err.Error())
}
apiServer.Run(":8080")
}

257
cmd/api/words.txt Normal file
View File

@ -0,0 +1,257 @@
be
to
of
it
in
we
do
he
on
go
at
if
or
up
by
hi
the
and
you
not
for
but
say
get
she
one
all
can
out
who
now
see
way
how
lot
yes
use
any
day
try
put
let
why
new
off
big
too
ask
man
bit
end
may
own
run
pay
job
old
kid
bad
few
ago
far
buy
set
guy
car
sit
war
win
yet
top
law
cut
low
die
eat
age
hit
air
add
boy
act
tax
oil
eye
son
key
fun
dad
dog
arm
fly
box
gas
lie
hot
gun
per
art
red
fit
bed
fan
mix
mom
sex
bus
fix
bar
lay
ice
bet
bag
due
aid
tie
leg
ban
odd
cup
dry
cry
rid
pop
sir
cat
map
sad
sea
aim
sun
fat
row
egg
tea
god
wed
tip
ear
hat
net
ill
dig
fee
mad
gap
nor
bid
era
toy
sky
bin
owe
wet
tap
pro
ski
cow
pen
van
web
pot
sum
cap
log
pub
pig
joy
raw
rat
via
lip
two
six
ten
lab
ton
mid
bat
hip
gut
sin
non
rub
sub
par
pre
ray
cue
dye
fin
ion
neo
hey
wow
mum
bye
aye
jet
sue
pet
flu
cop
ooh
rip
spy
pie
bug
gum
wan
rap
nut
beg
pin
pit
jam
tag
fax
vet
fry
pad
lad
mud
bay
con
pan
gee
toe
dip
shy
gym
zoo
fox
bow
tin
hop
wee
kit
opt
vow
sew
cab
bee
rob
rig
yep
ego
rib
nod
hug
lap
ash
hum
dam
bum
yen
jar

18
cmd/dns/main.go Normal file
View File

@ -0,0 +1,18 @@
package main
import (
"log"
smegdns "github.com/tim-beatham/wgmesh/pkg/dns"
)
func main() {
server, err := smegdns.NewDns(53)
if err != nil {
log.Fatal(err.Error())
}
defer server.Close()
server.Listen()
}

View File

@ -8,7 +8,9 @@ import (
"time"
"github.com/akamensky/argparse"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
@ -16,7 +18,6 @@ const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct {
Client *ipcRpc.Client
IfName string
WgPort int
Endpoint string
}
@ -24,7 +25,6 @@ type CreateMeshParams struct {
func createMesh(args *CreateMeshParams) string {
var reply string
newMeshParams := ipc.NewMeshArgs{
IfName: args.IfName,
WgPort: args.WgPort,
Endpoint: args.Endpoint,
}
@ -68,7 +68,6 @@ func joinMesh(params *JoinMeshParams) string {
args := ipc.JoinMeshArgs{
MeshId: params.MeshId,
IpAdress: params.IpAddress,
IfName: params.IfName,
Port: params.WgPort,
}
@ -96,9 +95,13 @@ func getMesh(client *ipcRpc.Client, meshId string) {
fmt.Println("Control Endpoint: " + node.HostEndpoint)
fmt.Println("WireGuard Endpoint: " + node.WgEndpoint)
fmt.Println("Wg IP: " + node.WgHost)
fmt.Println(fmt.Sprintf("Timestamp: %s", time.Unix(node.Timestamp, 0).String()))
fmt.Printf("Timestamp: %s", time.Unix(node.Timestamp, 0).String())
advertiseRoutes := strings.Join(node.Routes, ",")
mapFunc := func(r ctrlserver.MeshRoute) string {
return r.Destination
}
advertiseRoutes := strings.Join(lib.Map(node.Routes, mapFunc), ",")
fmt.Printf("Routes: %s\n", advertiseRoutes)
fmt.Println("---")
@ -171,6 +174,68 @@ func putDescription(client *ipcRpc.Client, description string) {
fmt.Println(reply)
}
// putAlias: puts an alias for the node
func putAlias(client *ipcRpc.Client, alias string) {
var reply string
err := client.Call("IpcHandler.PutAlias", &alias, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func setService(client *ipcRpc.Client, service, value string) {
var reply string
serviceArgs := &ipc.PutServiceArgs{
Service: service,
Value: value,
}
err := client.Call("IpcHandler.PutService", serviceArgs, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func deleteService(client *ipcRpc.Client, service string) {
var reply string
err := client.Call("IpcHandler.PutService", &service, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func getNode(client *ipcRpc.Client, nodeId, meshId string) {
var reply string
args := &ipc.GetNodeArgs{
NodeId: nodeId,
MeshId: meshId,
}
err := client.Call("IpcHandler.GetNode", &args, &reply)
if err != nil {
fmt.Println(err.Error())
return
}
fmt.Println(reply)
}
func main() {
parser := argparse.NewParser("wg-mesh",
"wg-mesh Manipulate WireGuard meshes")
@ -178,25 +243,24 @@ func main() {
newMeshCmd := parser.NewCommand("new-mesh", "Create a new mesh")
listMeshCmd := parser.NewCommand("list-meshes", "List meshes the node is connected to")
joinMeshCmd := parser.NewCommand("join-mesh", "Join a mesh network")
// getMeshCmd := parser.NewCommand("get-mesh", "Get a mesh network")
enableInterfaceCmd := parser.NewCommand("enable-interface", "Enable A Specific Mesh Interface")
getGraphCmd := parser.NewCommand("get-graph", "Convert a mesh into DOT format")
leaveMeshCmd := parser.NewCommand("leave-mesh", "Leave a mesh network")
queryMeshCmd := parser.NewCommand("query-mesh", "Query a mesh network using JMESPath")
putDescriptionCmd := parser.NewCommand("put-description", "Place a description for the node")
putAliasCmd := parser.NewCommand("put-alias", "Place an alias for the node")
setServiceCmd := parser.NewCommand("set-service", "Place a service into your advertisements")
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
// var getMeshId *string = getMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
var getGraphMeshId *string = getGraphCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -208,6 +272,16 @@ func main() {
var description *string = putDescriptionCmd.String("d", "description", &argparse.Options{Required: true})
var alias *string = putAliasCmd.String("a", "alias", &argparse.Options{Required: true})
var serviceKey *string = setServiceCmd.String("s", "service", &argparse.Options{Required: true})
var serviceValue *string = setServiceCmd.String("v", "value", &argparse.Options{Required: true})
var deleteServiceKey *string = deleteServiceCmd.String("s", "service", &argparse.Options{Required: true})
var getNodeNodeId *string = getNodeCmd.String("n", "nodeid", &argparse.Options{Required: true})
var getNodeMeshId *string = getNodeCmd.String("m", "meshid", &argparse.Options{Required: true})
err := parser.Parse(os.Args)
if err != nil {
@ -224,7 +298,6 @@ func main() {
if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{
Client: client,
IfName: *newMeshIfName,
WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint,
}))
@ -237,7 +310,6 @@ func main() {
if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{
Client: client,
IfName: *joinMeshIfName,
WgPort: *joinMeshPort,
IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId,
@ -245,10 +317,6 @@ func main() {
}))
}
// if getMeshCmd.Happened() {
// getMesh(client, *getMeshId)
// }
if getGraphCmd.Happened() {
getGraph(client, *getGraphMeshId)
}
@ -268,4 +336,20 @@ func main() {
if putDescriptionCmd.Happened() {
putDescription(client, *description)
}
if putAliasCmd.Happened() {
putAlias(client, *alias)
}
if setServiceCmd.Happened() {
setService(client, *serviceKey, *serviceValue)
}
if deleteServiceCmd.Happened() {
deleteService(client, *deleteServiceKey)
}
if getNodeCmd.Happened() {
getNode(client, *getNodeNodeId, *getNodeMeshId)
}
}

View File

@ -2,6 +2,7 @@ certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
@ -9,4 +10,5 @@ syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveRate: 60
keepAliveTime: 10
pruneTime: 20

View File

@ -1,7 +1,8 @@
package main
import (
"log"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
@ -9,9 +10,10 @@ import (
ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/robin"
"github.com/tim-beatham/wgmesh/pkg/sync"
"github.com/tim-beatham/wgmesh/pkg/timestamp"
timer "github.com/tim-beatham/wgmesh/pkg/timers"
"golang.zx2c4.com/wireguard/wgctrl"
)
@ -34,6 +36,12 @@ func main() {
return
}
if conf.Profile {
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
}
var robinRpc robin.WgRpc
var robinIpc robin.IpcHandler
var syncProvider sync.SyncServiceImpl
@ -49,7 +57,9 @@ func main() {
syncProvider.Server = ctrlServer
syncRequester := sync.NewSyncRequester(ctrlServer)
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester)
timestampScheduler := timestamp.NewTimestampScheduler(ctrlServer)
timestampScheduler := timer.NewTimestampScheduler(ctrlServer)
pruneScheduler := mesh.NewPruner(ctrlServer.MeshManager, *conf)
routeScheduler := timer.NewRouteScheduler(ctrlServer)
robinIpcParams := robin.RobinIpcParams{
CtrlServer: ctrlServer,
@ -63,11 +73,13 @@ func main() {
return
}
log.Println("Running IPC Handler")
logging.Log.WriteInfof("Running IPC Handler")
go ipc.RunIpcHandler(&robinIpc)
go syncScheduler.Run()
go timestampScheduler.Run()
go pruneScheduler.Run()
go routeScheduler.Run()
closeResources := func() {
logging.Log.WriteInfof("Closing resources")

View File

@ -0,0 +1,95 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
net-2:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.155.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-4:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
sysctls:
- net.ipv6.conf.all.forwarding=1
networks:
- net-1
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-5:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-6:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-7:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-2
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -0,0 +1,14 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

View File

@ -0,0 +1,42 @@
version: '3'
networks:
net-1:
driver: bridge
ipam:
driver: default
config:
- subnet: 10.89.0.0/17
services:
wg-1:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-2:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"
wg-3:
image: wg-mesh-base:latest
cap_add:
- NET_ADMIN
- NET_RAW
tty: true
networks:
- net-1
volumes:
- ./shared:/shared
command: "wgmeshd /shared/configuration.yaml"

View File

@ -0,0 +1,14 @@
certificatePath: "/wgmesh/cert/cert.pem"
privateKeyPath: "/wgmesh/cert/priv.pem"
caCertificatePath: "/wgmesh/cert/cacert.pem"
skipCertVerification: true
timeout: 5
gRPCPort: "21906"
advertiseRoutes: true
clusterSize: 32
syncRate: 1
interClusterChance: 0.15
branchRate: 3
infectionCount: 3
keepAliveTime: 10
pruneTime: 20

22
go.mod
View File

@ -5,9 +5,12 @@ go 1.21.3
require (
github.com/akamensky/argparse v1.4.0
github.com/automerge/automerge-go v0.0.0-20230903201930-b80ce8aadbb9
github.com/gin-gonic/gin v1.9.1
github.com/google/uuid v1.3.0
github.com/jmespath/go-jmespath v0.4.0
github.com/jsimonetti/rtnetlink v1.3.5
github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.14.0
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
google.golang.org/grpc v1.58.1
google.golang.org/protobuf v1.31.0
@ -15,16 +18,33 @@ require (
)
require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/go-cmp v0.5.9 // indirect
github.com/josharian/native v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mdlayher/genetlink v1.3.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.13.0 // indirect
golang.org/x/net v0.15.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20230711160842-782d3b101e98 // indirect

237
pkg/api/apiserver.go Normal file
View File

@ -0,0 +1,237 @@
package api
import (
"fmt"
"net/http"
ipcRpc "net/rpc"
"github.com/gin-gonic/gin"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/what8words"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
type ApiServer interface {
GetMeshes(c *gin.Context)
Run(addr string) error
}
type SmegServer struct {
router *gin.Engine
client *ipcRpc.Client
words *what8words.What8Words
}
func (s *SmegServer) routeToApiRoute(meshNode ctrlserver.MeshNode) []Route {
routes := make([]Route, len(meshNode.Routes))
for index, route := range meshNode.Routes {
if route.Path == nil {
route.Path = make([]string, 0)
}
routes[index] = Route{
Prefix: route.Destination,
Path: route.Path,
}
}
return routes
}
func (s *SmegServer) meshNodeToAPIMeshNode(meshNode ctrlserver.MeshNode) *SmegNode {
if meshNode.Routes == nil {
meshNode.Routes = make([]ctrlserver.MeshRoute, 0)
}
alias := meshNode.Alias
if alias == "" {
alias, _ = s.words.ConvertIdentifier(meshNode.WgHost)
}
return &SmegNode{
WgHost: meshNode.WgHost,
WgEndpoint: meshNode.WgEndpoint,
Endpoint: meshNode.HostEndpoint,
Timestamp: int(meshNode.Timestamp),
Description: meshNode.Description,
Routes: s.routeToApiRoute(meshNode),
PublicKey: meshNode.PublicKey,
Alias: alias,
Services: meshNode.Services,
}
}
func (s *SmegServer) meshToAPIMesh(meshId string, nodes []ctrlserver.MeshNode) SmegMesh {
var smegMesh SmegMesh
smegMesh.MeshId = meshId
smegMesh.Nodes = make(map[string]SmegNode)
for _, node := range nodes {
smegMesh.Nodes[node.WgHost] = *s.meshNodeToAPIMeshNode(node)
}
return smegMesh
}
// CreateMesh: creates a mesh network
func (s *SmegServer) CreateMesh(c *gin.Context) {
var createMesh CreateMeshRequest
if err := c.ShouldBindJSON(&createMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.NewMeshArgs{
WgPort: createMesh.WgPort,
}
var reply string
err := s.client.Call("IpcHandler.CreateMesh", &ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, &gin.H{
"meshid": reply,
})
}
// JoinMesh: joins a mesh network
func (s *SmegServer) JoinMesh(c *gin.Context) {
var joinMesh JoinMeshRequest
if err := c.ShouldBindJSON(&joinMesh); err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap,
Port: joinMesh.WgPort,
}
var reply string
err := s.client.Call("IpcHandler.JoinMesh", &ipcRequest, &reply)
if err != nil {
c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, &gin.H{
"status": "success",
})
}
// GetMesh: given a meshId returns the corresponding mesh
// network.
func (s *SmegServer) GetMesh(c *gin.Context) {
meshidParam := c.Param("meshid")
var meshid string = meshidParam
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &meshid, &getMeshReply)
if err != nil {
c.JSON(http.StatusNotFound,
&gin.H{
"error": fmt.Sprintf("could not find mesh %s", meshidParam),
})
return
}
mesh := s.meshToAPIMesh(meshidParam, getMeshReply.Nodes)
c.JSON(http.StatusOK, mesh)
}
func (s *SmegServer) GetMeshes(c *gin.Context) {
listMeshesReply := new(ipc.ListMeshReply)
err := s.client.Call("IpcHandler.ListMeshes", "", &listMeshesReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes := make([]SmegMesh, 0)
for _, mesh := range listMeshesReply.Meshes {
getMeshReply := new(ipc.GetMeshReply)
err := s.client.Call("IpcHandler.GetMesh", &mesh, &getMeshReply)
if err != nil {
logging.Log.WriteErrorf(err.Error())
c.JSON(http.StatusBadRequest, nil)
return
}
meshes = append(meshes, s.meshToAPIMesh(mesh, getMeshReply.Nodes))
}
c.JSON(http.StatusOK, meshes)
}
func (s *SmegServer) Run(addr string) error {
logging.Log.WriteInfof("Running API server")
return s.router.Run(addr)
}
func NewSmegServer(conf ApiServerConf) (ApiServer, error) {
client, err := ipcRpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
words, err := what8words.NewWhat8Words(conf.WordsFile)
if err != nil {
return nil, err
}
router := gin.Default()
router.Use(gin.LoggerWithConfig(gin.LoggerConfig{
Output: logging.Log.Writer(),
}))
smegServer := &SmegServer{
router: router,
client: client,
words: words,
}
router.GET("/meshes", smegServer.GetMeshes)
router.GET("/mesh/:meshid", smegServer.GetMesh)
router.POST("/mesh/create", smegServer.CreateMesh)
router.POST("/mesh/join", smegServer.JoinMesh)
return smegServer, nil
}

37
pkg/api/types.go Normal file
View File

@ -0,0 +1,37 @@
package api
type Route struct {
Prefix string `json:"prefix"`
Path []string `json:"path"`
}
type SmegNode struct {
Alias string `json:"alias"`
WgHost string `json:"wgHost"`
WgEndpoint string `json:"wgEndpoint"`
Endpoint string `json:"endpoint"`
Timestamp int `json:"timestamp"`
Description string `json:"description"`
PublicKey string `json:"publicKey"`
Routes []Route `json:"routes"`
Services map[string]string `json:"services"`
}
type SmegMesh struct {
MeshId string `json:"meshid"`
Nodes map[string]SmegNode `json:"nodes"`
}
type CreateMeshRequest struct {
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
}
type JoinMeshRequest struct {
WgPort int `json:"port" binding:"omitempty,gte=1024,lt=65535"`
Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"`
}
type ApiServerConf struct {
WordsFile string
}

View File

@ -20,11 +20,12 @@ import (
type CrdtMeshManager struct {
MeshId string
IfName string
NodeId string
Client *wgctrl.Client
doc *automerge.Doc
LastHash automerge.ChangeHash
conf *conf.WgMeshConfiguration
cache *MeshCrdt
lastCacheHash automerge.ChangeHash
}
func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
@ -34,15 +35,78 @@ func (c *CrdtMeshManager) AddNode(node mesh.MeshNode) {
panic("node must be of type *MeshNodeCrdt")
}
crdt.Routes = make(map[string]Route)
crdt.Services = make(map[string]string)
crdt.Timestamp = time.Now().Unix()
c.doc.Path("nodes").Map().Set(crdt.HostEndpoint, crdt)
nodeVal, _ := c.doc.Path("nodes").Map().Get(crdt.HostEndpoint)
nodeVal.Map().Set("routes", automerge.NewMap())
c.doc.Path("nodes").Map().Set(crdt.PublicKey, crdt)
}
func (c *CrdtMeshManager) isPeer(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
nodeType, err := node.Map().Get("type")
if err != nil || nodeType.Kind() != automerge.KindStr {
return false
}
return nodeType.Str() == string(conf.PEER_ROLE)
}
// isAlive: checks that the node's configuration has been updated
// since the rquired keep alive time
func (c *CrdtMeshManager) isAlive(nodeId string) bool {
node, err := c.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return false
}
timestamp, err := node.Map().Get("timestamp")
if err != nil || timestamp.Kind() != automerge.KindInt64 {
return false
}
keepAliveTime := timestamp.Int64()
return (time.Now().Unix() - keepAliveTime) < int64(c.conf.DeadTime)
}
func (c *CrdtMeshManager) GetPeers() []string {
keys, _ := c.doc.Path("nodes").Map().Keys()
keys = lib.Filter(keys, func(publicKey string) bool {
return c.isPeer(publicKey) && c.isAlive(publicKey)
})
return keys
}
// GetMesh(): Converts the document into a struct
func (c *CrdtMeshManager) GetMesh() (mesh.MeshSnapshot, error) {
return automerge.As[*MeshCrdt](c.doc.Root())
changes, err := c.doc.Changes(c.lastCacheHash)
if err != nil {
return nil, err
}
if c.cache == nil || len(changes) > 0 {
c.lastCacheHash = c.LastHash
cache, err := automerge.As[*MeshCrdt](c.doc.Root())
if err != nil {
return nil, err
}
c.cache = cache
}
return c.cache, nil
}
// GetMeshId returns the meshid of the mesh
@ -82,13 +146,23 @@ func NewCrdtNodeManager(params *NewCrdtNodeMangerParams) (*CrdtMeshManager, erro
manager.IfName = params.DevName
manager.Client = params.Client
manager.conf = &params.Conf
manager.cache = nil
return &manager, nil
}
// GetNode: returns a mesh node crdt.Close releases resources used by a Client.
func (m *CrdtMeshManager) GetNode(endpoint string) (*MeshNodeCrdt, error) {
// NodeExists: returns true if the node exists. Returns false
func (m *CrdtMeshManager) NodeExists(key string) bool {
node, err := m.doc.Path("nodes").Map().Get(key)
return node.Kind() == automerge.KindMap && err == nil
}
func (m *CrdtMeshManager) GetNode(endpoint string) (mesh.MeshNode, error) {
node, err := m.doc.Path("nodes").Map().Get(endpoint)
if node.Kind() != automerge.KindMap {
return nil, fmt.Errorf("GetNode: something went wrong %s is not a map type")
}
if err != nil {
return nil, err
}
@ -178,8 +252,74 @@ func (m *CrdtMeshManager) SetDescription(nodeId string, description string) erro
return err
}
func (m *CrdtMeshManager) SetAlias(nodeId string, alias string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if node.Kind() != automerge.KindMap {
return fmt.Errorf("%s does not exist", nodeId)
}
err = node.Map().Set("alias", alias)
if err == nil {
logging.Log.WriteInfof("Updated Alias for %s to %s", nodeId, alias)
}
return err
}
func (m *CrdtMeshManager) AddService(nodeId, key, value string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("AddService: services property does not exist in node")
}
return service.Map().Set(key, value)
}
func (m *CrdtMeshManager) RemoveService(nodeId, key string) error {
node, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil || node.Kind() != automerge.KindMap {
return fmt.Errorf("RemoveService: node %s does not exist", nodeId)
}
service, err := node.Map().Get("services")
if err != nil {
return err
}
if service.Kind() != automerge.KindMap {
return fmt.Errorf("services property does not exist")
}
err = service.Map().Delete(key)
if err != nil {
return fmt.Errorf("service %s does not exist", key)
}
return nil
}
// AddRoutes: adds routes to the specific nodeId
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...mesh.Route) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
logging.Log.WriteInfof("Adding route to %s", nodeId)
@ -198,20 +338,160 @@ func (m *CrdtMeshManager) AddRoutes(nodeId string, routes ...string) error {
}
for _, route := range routes {
err = routeMap.Map().Set(route, struct{}{})
err = routeMap.Map().Set(route.GetDestination().String(), Route{
Destination: route.GetDestination().String(),
Path: route.GetPath(),
})
if err != nil {
return err
}
}
return nil
}
func (m *CrdtMeshManager) getRoutes(nodeId string) ([]Route, error) {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return nil, err
}
if nodeVal.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node does not exist")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return nil, err
}
if routeMap.Kind() != automerge.KindMap {
return nil, fmt.Errorf("node %s is not a map", nodeId)
}
routes, err := automerge.As[map[string]Route](routeMap)
return lib.MapValues(routes), err
}
func (m *CrdtMeshManager) GetRoutes(targetNode string) (map[string]mesh.Route, error) {
node, err := m.GetNode(targetNode)
if err != nil {
return nil, err
}
routes := make(map[string]mesh.Route)
for _, route := range node.GetRoutes() {
routes[route.GetDestination().String()] = route
}
for _, node := range m.GetPeers() {
nodeRoutes, err := m.getRoutes(node)
if err != nil {
return nil, err
}
for _, route := range nodeRoutes {
otherRoute, ok := routes[route.GetDestination().String()]
if !ok || route.GetHopCount()+1 < otherRoute.GetHopCount() {
routes[route.GetDestination().String()] = &Route{
Destination: route.GetDestination().String(),
Path: append(route.Path, m.GetMeshId()),
}
}
}
}
return routes, nil
}
// DeleteRoutes deletes the specified routes
func (m *CrdtMeshManager) RemoveRoutes(nodeId string, routes ...string) error {
nodeVal, err := m.doc.Path("nodes").Map().Get(nodeId)
if err != nil {
return err
}
if nodeVal.Kind() != automerge.KindMap {
return fmt.Errorf("node is not a map")
}
routeMap, err := nodeVal.Map().Get("routes")
if err != nil {
return err
}
for _, route := range routes {
err = routeMap.Map().Delete(route)
}
return err
}
func (m *CrdtMeshManager) GetSyncer() mesh.MeshSyncer {
return NewAutomergeSync(m)
}
func (m *CrdtMeshManager) Prune(pruneTime int) error {
nodes, err := m.doc.Path("nodes").Get()
if err != nil {
return err
}
if nodes.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
values, err := nodes.Map().Values()
if err != nil {
return err
}
deletionNodes := make([]string, 0)
for nodeId, node := range values {
if node.Kind() != automerge.KindMap {
return errors.New("node must be a map")
}
nodeMap := node.Map()
timeStamp, err := nodeMap.Get("timestamp")
if err != nil {
return err
}
if timeStamp.Kind() != automerge.KindInt64 {
return errors.New("timestamp is not int64")
}
timeValue := timeStamp.Int64()
nowValue := time.Now().Unix()
if nowValue-timeValue >= int64(pruneTime) {
deletionNodes = append(deletionNodes, nodeId)
}
}
for _, node := range deletionNodes {
logging.Log.WriteInfof("Pruning %s", node)
nodes.Map().Delete(node)
}
return nil
}
func (m1 *MeshNodeCrdt) Compare(m2 *MeshNodeCrdt) int {
return strings.Compare(m1.PublicKey, m2.PublicKey)
}
@ -243,8 +523,13 @@ func (m *MeshNodeCrdt) GetTimeStamp() int64 {
return m.Timestamp
}
func (m *MeshNodeCrdt) GetRoutes() []string {
return lib.MapKeys(m.Routes)
func (m *MeshNodeCrdt) GetRoutes() []mesh.Route {
return lib.Map(lib.MapValues(m.Routes), func(r Route) mesh.Route {
return &Route{
Destination: r.Destination,
Path: r.Path,
}
})
}
func (m *MeshNodeCrdt) GetDescription() string {
@ -260,6 +545,26 @@ func (m *MeshNodeCrdt) GetIdentifier() string {
return strings.Join(constituents, ":")
}
func (m *MeshNodeCrdt) GetAlias() string {
return m.Alias
}
func (m *MeshNodeCrdt) GetServices() map[string]string {
services := make(map[string]string)
for key, service := range m.Services {
services[key] = service
}
return services
}
// GetType refers to the type of the node. Peer means that the node is globally accessible
// Client means the node is only accessible through another peer
func (n *MeshNodeCrdt) GetType() conf.NodeType {
return conf.NodeType(n.Type)
}
func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
nodes := make(map[string]mesh.MeshNode)
@ -272,8 +577,24 @@ func (m *MeshCrdt) GetNodes() map[string]mesh.MeshNode {
Timestamp: node.Timestamp,
Routes: node.Routes,
Description: node.Description,
Alias: node.Alias,
Services: node.GetServices(),
Type: node.Type,
}
}
return nodes
}
func (r *Route) GetDestination() *net.IPNet {
_, ipnet, _ := net.ParseCIDR(r.Destination)
return ipnet
}
func (r *Route) GetHopCount() int {
return len(r.Path)
}
func (r *Route) GetPath() []string {
return r.Path
}

View File

@ -28,14 +28,23 @@ type MeshNodeFactory struct {
func (f *MeshNodeFactory) Build(params *mesh.MeshNodeFactoryParams) mesh.MeshNode {
hostName := f.getAddress(params)
grpcEndpoint := fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort)
if f.Config.Role == conf.CLIENT_ROLE {
grpcEndpoint = "-"
}
return &MeshNodeCrdt{
HostEndpoint: fmt.Sprintf("%s:%s", hostName, f.Config.GrpcPort),
HostEndpoint: grpcEndpoint,
PublicKey: params.PublicKey.String(),
WgEndpoint: fmt.Sprintf("%s:%d", hostName, params.WgPort),
WgHost: fmt.Sprintf("%s/128", params.NodeIP.String()),
// Always set the routes as empty.
// Routes handled by external component
Routes: map[string]interface{}{},
Routes: make(map[string]Route),
Description: "",
Alias: "",
Type: string(f.Config.Role),
}
}
@ -48,7 +57,19 @@ func (f *MeshNodeFactory) getAddress(params *mesh.MeshNodeFactoryParams) string
} else if len(f.Config.Endpoint) != 0 {
hostName = f.Config.Endpoint
} else {
hostName = lib.GetOutboundIP().String()
ipFunc := lib.GetPublicIP
if f.Config.IPDiscovery == conf.DNS_IP_DISCOVERY {
ipFunc = lib.GetOutboundIP
}
ip, err := ipFunc()
if err != nil {
return ""
}
hostName = ip.String()
}
return hostName

View File

@ -1,5 +1,11 @@
package crdt
// Route: Represents a CRDT of the given route
type Route struct {
Destination string `automerge:"destination"`
Path []string `automerge:"path"`
}
// MeshNodeCrdt: Represents a CRDT for a mesh nodes
type MeshNodeCrdt struct {
HostEndpoint string `automerge:"hostEndpoint"`
@ -7,8 +13,11 @@ type MeshNodeCrdt struct {
PublicKey string `automerge:"publicKey"`
WgHost string `automerge:"wgHost"`
Timestamp int64 `automerge:"timestamp"`
Routes map[string]interface{} `automerge:"routes"`
Routes map[string]Route `automerge:"routes"`
Alias string `automerge:"alias"`
Description string `automerge:"description"`
Services map[string]string `automerge:"services"`
Type string `automerge:"type"`
}
// MeshCrdt: Represents the mesh network as a whole

View File

@ -16,6 +16,20 @@ func (m *WgMeshConfigurationError) Error() string {
return m.msg
}
type NodeType string
const (
PEER_ROLE NodeType = "peer"
CLIENT_ROLE NodeType = "client"
)
type IPDiscovery string
const (
PUBLIC_IP_DISCOVERY = "public"
DNS_IP_DISCOVERY = "dns"
)
type WgMeshConfiguration struct {
// CertificatePath is the path to the certificate to use in mTLS
CertificatePath string `yaml:"certificatePath"`
@ -28,17 +42,43 @@ type WgMeshConfiguration struct {
SkipCertVerification bool `yaml:"skipCertVerification"`
// Port to run the GrpcServer on
GrpcPort string `yaml:"gRPCPort"`
// IPDIscovery: how to discover your IP if not specified. Use DNS server 8.8.8.8 or
// use public IP discovery library
IPDiscovery IPDiscovery `yaml:"ipDiscovery"`
// AdvertiseRoutes advertises other meshes if the node is in multiple meshes
AdvertiseRoutes bool `yaml:"advertiseRoutes"`
// Endpoint is the IP in which this computer is publicly reachable.
// usecase is when the node has multiple IP addresses
Endpoint string `yaml:"publicEndpoint"`
// ClusterSize size of the cluster to split on
ClusterSize int `yaml:"clusterSize"`
// SyncRate number of times per second to perform a sync
SyncRate float64 `yaml:"syncRate"`
// InterClusterChance proability of inter-cluster communication in a sync round
InterClusterChance float64 `yaml:"interClusterChance"`
// BranchRate number of nodes to randomly communicate with
BranchRate int `yaml:"branchRate"`
// InfectionCount number of times we sync before we can no longer catch the udpate
InfectionCount int `yaml:"infectionCount"`
KeepAliveRate int `yaml:"keepAliveRate"`
// KeepAliveTime number of seconds before we update node indicating that we are still alive
KeepAliveTime int `yaml:"keepAliveTime"`
// Timeout number of seconds before we consider the node as dead
Timeout int `yaml:"timeout"`
// PruneTime number of seconds before we remove nodes that are likely to be dead
PruneTime int `yaml:"pruneTime"`
// DeadTime: number of seconds before we consider the node as dead and stop considering it
// when picking a random peer
DeadTime int `yaml:"deadTime"`
// Profile whether or not to include a http server that profiles the code
Profile bool `yaml:"profile"`
// StubWg whether or not to stub the WireGuard types
StubWg bool `yaml:"stubWg"`
// Role specifies whether or not the user is globally accessible.
// If the user is globaly accessible they specify themselves as a client.
Role NodeType `yaml:"role"`
// KeepAliveWg configures the implementation so that we send keep alive packets to peers.
// KeepAlive can only be set if role is type client
KeepAliveWg int `yaml:"keepAliveWg"`
}
func ValidateConfiguration(c *WgMeshConfiguration) error {
@ -90,7 +130,7 @@ func ValidateConfiguration(c *WgMeshConfiguration) error {
}
}
if c.KeepAliveRate <= 0 {
if c.KeepAliveTime <= 0 {
return &WgMeshConfigurationError{
msg: "KeepAliveRate cannot be less than negative",
}
@ -102,6 +142,38 @@ func ValidateConfiguration(c *WgMeshConfiguration) error {
}
}
if c.Timeout < 1 {
return &WgMeshConfigurationError{
msg: "Timeout should be greater than or equal to 1",
}
}
if c.PruneTime < 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be < 1",
}
}
if c.DeadTime < 1 {
return &WgMeshConfigurationError{
msg: "Dead time cannot be < 1",
}
}
if c.KeepAliveTime <= 1 {
return &WgMeshConfigurationError{
msg: "Prune time cannot be less than keep alive time",
}
}
if c.Role == "" {
c.Role = PEER_ROLE
}
if c.IPDiscovery == "" {
c.IPDiscovery = PUBLIC_IP_DISCOVERY
}
return nil
}

View File

@ -15,8 +15,10 @@ func getExampleConfiguration() *WgMeshConfiguration {
SyncRate: 1,
InterClusterChance: 0.1,
BranchRate: 2,
KeepAliveRate: 1,
KeepAliveTime: 4,
InfectionCount: 1,
Timeout: 2,
PruneTime: 20,
}
}
@ -110,7 +112,7 @@ func InfectionCountZero(t *testing.T) {
func KeepAliveRateZero(t *testing.T) {
conf := getExampleConfiguration()
conf.KeepAliveRate = 0
conf.KeepAliveTime = 0
err := ValidateConfiguration(conf)
@ -128,3 +130,36 @@ func TestValidCOnfiguration(t *testing.T) {
t.Error(err)
}
}
func TestTimeout(t *testing.T) {
conf := getExampleConfiguration()
conf.Timeout = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatal(`error should be thrown`)
}
}
func TestPruneTimeZero(t *testing.T) {
conf := getExampleConfiguration()
conf.PruneTime = 0
err := ValidateConfiguration(conf)
if err == nil {
t.Fatalf(`Error should be thrown`)
}
}
func TestPruneTimeLessThanKeepAliveTime(t *testing.T) {
conf := getExampleConfiguration()
conf.PruneTime = 1
err := ValidateConfiguration(conf)
if err == nil {
t.Fatalf(`Error should be thrown`)
}
}

View File

@ -31,10 +31,12 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
nodeFactory := crdt.MeshNodeFactory{
Config: *params.Conf,
}
idGenerator := &lib.UUIDGenerator{}
idGenerator := &lib.IDNameGenerator{}
ipAllocator := &ip.ULABuilder{}
interfaceManipulator := wg.NewWgInterfaceManipulator(params.Client)
configApplyer := mesh.NewWgMeshConfigApplyer(params.Conf)
meshManagerParams := &mesh.NewMeshManagerParams{
Conf: *params.Conf,
Client: params.Client,
@ -43,9 +45,11 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
IdGenerator: idGenerator,
IPAllocator: ipAllocator,
InterfaceManipulator: interfaceManipulator,
ConfigApplyer: mesh.NewWgMeshConfigApplyer(ctrlServer.MeshManager),
ConfigApplyer: configApplyer,
}
ctrlServer.MeshManager = mesh.NewMeshManager(meshManagerParams)
configApplyer.SetMeshManager(ctrlServer.MeshManager)
ctrlServer.Conf = params.Conf
connManagerParams := conn.NewConnectionManagerParams{
@ -107,6 +111,10 @@ func (s *MeshCtrlServer) Close() error {
logging.Log.WriteErrorf(err.Error())
}
if err := s.MeshManager.Close(); err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err := s.ConnectionServer.Close(); err != nil {
logging.Log.WriteErrorf(err.Error())
}

View File

@ -9,6 +9,11 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type MeshRoute struct {
Destination string
Path []string
}
// Represents a WireGuard MeshNode
type MeshNode struct {
HostEndpoint string
@ -16,7 +21,10 @@ type MeshNode struct {
PublicKey string
WgHost string
Timestamp int64
Routes []string
Routes []MeshRoute
Description string
Alias string
Services map[string]string
}
// Represents a WireGuard Mesh

114
pkg/dns/dns.go Normal file
View File

@ -0,0 +1,114 @@
package smegdns
import (
"encoding/json"
"fmt"
"net"
"net/rpc"
"github.com/miekg/dns"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/query"
)
const SockAddr = "/tmp/wgmesh_ipc.sock"
const MeshRegularExpression = `(?P<meshId>.+)\.(?P<alias>.+)\.smeg\.`
type DNSHandler struct {
client *rpc.Client
server *dns.Server
}
// queryMesh: queries the mesh network for the given meshId and node
// with alias
func (d *DNSHandler) queryMesh(meshId, alias string) net.IP {
var reply string
err := d.client.Call("IpcHandler.Query", &ipc.QueryMesh{
MeshId: meshId,
Query: fmt.Sprintf("[?alias == '%s'] | [0]", alias),
}, &reply)
if err != nil {
return nil
}
var node *query.QueryNode
err = json.Unmarshal([]byte(reply), &node)
if err != nil || node == nil {
return nil
}
ip, _, _ := net.ParseCIDR(node.WgHost)
return ip
}
func (d *DNSHandler) handleQuery(m *dns.Msg) {
for _, q := range m.Question {
switch q.Qtype {
case dns.TypeAAAA:
logging.Log.WriteInfof("Query for %s", q.Name)
groups := lib.MatchCaptureGroup(MeshRegularExpression, q.Name)
if len(groups) == 0 {
continue
}
ip := d.queryMesh(groups["meshId"], groups["alias"])
rr, err := dns.NewRR(fmt.Sprintf("%s AAAA %s", q.Name, ip))
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
if err == nil {
m.Answer = append(m.Answer, rr)
}
}
}
}
func (h *DNSHandler) handleDnsRequest(w dns.ResponseWriter, r *dns.Msg) {
msg := new(dns.Msg)
msg.SetReply(r)
msg.Authoritative = true
switch r.Opcode {
case dns.OpcodeQuery:
h.handleQuery(msg)
}
w.WriteMsg(msg)
}
func (h *DNSHandler) Listen() error {
return h.server.ListenAndServe()
}
func (h *DNSHandler) Close() error {
return h.server.Shutdown()
}
func NewDns(udpPort int) (*DNSHandler, error) {
client, err := rpc.DialHTTP("unix", SockAddr)
if err != nil {
return nil, err
}
dnsHander := DNSHandler{
client: client,
}
dns.HandleFunc("smeg.", dnsHander.handleDnsRequest)
dnsHander.server = &dns.Server{Addr: fmt.Sprintf(":%d", udpPort), Net: "udp"}
return &dnsHander, nil
}

View File

@ -4,13 +4,13 @@ package rpctypes;
option go_package = "pkg/rpc";
service MeshCtrlServer {
rpc JoinMesh(JoinMeshRequest) returns (JoinMeshReply) {}
rpc GetMesh(GetMeshRequest) returns (GetMeshReply) {}
}
message JoinMeshRequest {
string meshId = 2;
message GetMeshRequest {
string meshId = 1;
}
message JoinMeshReply {
bool success = 1;
message GetMeshReply {
bytes mesh = 1;
}

132
pkg/hosts/hosts.go Normal file
View File

@ -0,0 +1,132 @@
// hosts: utility for modifying the /etc/hosts file
package hosts
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"os"
"strings"
)
// HOSTS_FILE is the hosts file location
const HOSTS_FILE = "/etc/hosts"
const DOMAIN_HEADER = "#WG AUTO GENERATED HOSTS"
const DOMAIN_TRAILER = "#WG AUTO GENERATED HOSTS END"
type HostsEntry struct {
Alias string
Ip net.IP
}
// Generic interface to manipulate /etc/hosts file
type HostsManipulator interface {
// AddrAddr associates an aliasd with a given IP address
AddAddr(hosts ...HostsEntry)
// Remove deletes the entry from /etc/hosts
Remove(hosts ...HostsEntry)
// Writes the changes to /etc/hosts file
Write() error
}
type HostsManipulatorImpl struct {
hosts map[string]HostsEntry
}
// AddAddr implements HostsManipulator.
func (m *HostsManipulatorImpl) AddAddr(hosts ...HostsEntry) {
changed := false
for _, host := range hosts {
prev, ok := m.hosts[host.Ip.String()]
if !ok || prev.Alias != host.Alias {
changed = true
}
m.hosts[host.Ip.String()] = host
}
if changed {
m.Write()
}
}
// Remove implements HostsManipulator.
func (m *HostsManipulatorImpl) Remove(hosts ...HostsEntry) {
lenBefore := len(m.hosts)
for _, host := range hosts {
delete(m.hosts, host.Alias)
}
if lenBefore != len(m.hosts) {
m.Write()
}
}
func (m *HostsManipulatorImpl) removeHosts() string {
hostsFile, err := os.ReadFile(HOSTS_FILE)
if err != nil {
return ""
}
var contents strings.Builder
scanner := bufio.NewScanner(bytes.NewReader(hostsFile))
hostsSection := false
for scanner.Scan() {
line := scanner.Text()
if err == io.EOF {
break
} else if err != nil {
return ""
}
if !hostsSection && strings.Contains(line, DOMAIN_HEADER) {
hostsSection = true
}
if !hostsSection {
contents.WriteString(line + "\n")
}
if hostsSection && strings.Contains(line, DOMAIN_TRAILER) {
hostsSection = false
}
}
if scanner.Err() != nil && scanner.Err() != io.EOF {
return ""
}
return contents.String()
}
// Write implements HostsManipulator
func (m *HostsManipulatorImpl) Write() error {
contents := m.removeHosts()
var nextHosts strings.Builder
nextHosts.WriteString(contents)
nextHosts.WriteString(DOMAIN_HEADER + "\n")
for _, host := range m.hosts {
nextHosts.WriteString(fmt.Sprintf("%s\t%s\n", host.Ip.String(), host.Alias))
}
nextHosts.WriteString(DOMAIN_TRAILER + "\n")
return os.WriteFile(HOSTS_FILE, []byte(nextHosts.String()), 0644)
}
func NewHostsManipulator() HostsManipulator {
return &HostsManipulatorImpl{hosts: make(map[string]HostsEntry)}
}

View File

@ -11,8 +11,6 @@ import (
)
type NewMeshArgs struct {
// IfName is the interface that the mesh instance will run on
IfName string
// WgPort is the WireGuard port to expose
WgPort int
// Endpoint is the routable alias of the machine. Can be an IP
@ -25,13 +23,19 @@ type JoinMeshArgs struct {
MeshId string
// IpAddress is a routable IP in another mesh
IpAdress string
// IfName is the interface name of the mesh
IfName string
// Port is the WireGuard port to expose
Port int
// Endpoint is the routable address of this machine. If not provided
// defaults to the default address
Endpoint string
// Client specifies whether we should join as a client of the peer
// we are connecting to
Client bool
}
type PutServiceArgs struct {
Service string
Value string
}
type GetMeshReply struct {
@ -47,23 +51,31 @@ type QueryMesh struct {
Query string
}
type GetNodeArgs struct {
NodeId string
MeshId string
}
type MeshIpc interface {
CreateMesh(args *NewMeshArgs, reply *string) error
ListMeshes(name string, reply *ListMeshReply) error
JoinMesh(args JoinMeshArgs, reply *string) error
LeaveMesh(meshId string, reply *string) error
GetMesh(meshId string, reply *GetMeshReply) error
EnableInterface(meshId string, reply *string) error
GetDOT(meshId string, reply *string) error
Query(query QueryMesh, reply *string) error
PutDescription(description string, reply *string) error
PutAlias(alias string, reply *string) error
PutService(args PutServiceArgs, reply *string) error
GetNode(args GetNodeArgs, reply *string) error
DeleteService(service string, reply *string) error
}
const SockAddr = "/tmp/wgmesh_ipc.sock"
func RunIpcHandler(server MeshIpc) error {
if err := os.RemoveAll(SockAddr); err != nil {
return errors.New("Could not find to address")
return errors.New("could not find to address")
}
rpc.Register(server)

View File

@ -66,3 +66,13 @@ func Filter[V any](list []V, f filterFunc[V]) []V {
return newList
}
func Contains[V any](list []V, proposition func(V) bool) bool {
for _, elem := range list {
if proposition(elem) {
return true
}
}
return false
}

46
pkg/lib/hashing.go Normal file
View File

@ -0,0 +1,46 @@
package lib
import (
"hash/fnv"
"sort"
)
type consistentHashRecord[V any] struct {
record V
value int
}
func HashString(value string) int {
f := fnv.New32a()
f.Write([]byte(value))
return int(f.Sum32())
}
// ConsistentHash implementation. Traverse the values until we find a key
// less than ours.
func ConsistentHash[V any, K any](values []V, client K, bucketFunc func(V) int, keyFunc func(K) int) V {
if len(values) == 0 {
panic("values is empty")
}
vs := Map(values, func(v V) consistentHashRecord[V] {
return consistentHashRecord[V]{
v,
bucketFunc(v),
}
})
sort.SliceStable(vs, func(i, j int) bool {
return vs[i].value < vs[j].value
})
ourKey := keyFunc(client)
for _, record := range vs {
if ourKey < record.value {
return record.record
}
}
return vs[0].record
}

View File

@ -1,6 +1,9 @@
package lib
import "github.com/google/uuid"
import (
"github.com/anandvarma/namegen"
"github.com/google/uuid"
)
// IdGenerator generates unique ids
type IdGenerator interface {
@ -15,3 +18,11 @@ func (g *UUIDGenerator) GetId() (string, error) {
id := uuid.New()
return id.String(), nil
}
type IDNameGenerator struct {
}
func (i *IDNameGenerator) GetId() (string, error) {
name_schema := namegen.New()
return name_schema.Get(), nil
}

View File

@ -1,17 +1,61 @@
package lib
import (
"encoding/json"
"io"
"log"
"net"
"net/http"
)
// GetOutboundIP: gets the oubound IP of this packet
func GetOutboundIP() net.IP {
func GetOutboundIP() (net.IP, error) {
conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil {
log.Fatal(err)
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP
return localAddr.IP, nil
}
const IP_SERVICE = "https://api.ipify.org?format=json"
type IpResponse struct {
Ip string `json:"ip"`
}
func (i *IpResponse) GetIP() net.IP {
return net.ParseIP(i.Ip)
}
// GetPublicIP: get the nodes public IP address. For when a node is behind NAT
func GetPublicIP() (net.IP, error) {
req, err := http.NewRequest(http.MethodGet, IP_SERVICE, nil)
if err != nil {
return nil, err
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return nil, err
}
resBody, err := io.ReadAll(res.Body)
if err != nil {
return nil, err
}
var jsonResponse IpResponse
err = json.Unmarshal([]byte(resBody), &jsonResponse)
if err != nil {
return nil, err
}
return jsonResponse.GetIP(), nil
}

19
pkg/lib/regex.go Normal file
View File

@ -0,0 +1,19 @@
package lib
import "regexp"
func MatchCaptureGroup(pattern, payload string) map[string]string {
patterns := make(map[string]string)
expr := regexp.MustCompile(pattern)
match := expr.FindStringSubmatch(payload)
for i, name := range expr.SubexpNames() {
if i != 0 && name != "" {
patterns[name] = match[i]
}
}
return patterns
}

293
pkg/lib/rtnetlink.go Normal file
View File

@ -0,0 +1,293 @@
package lib
import (
"encoding/binary"
"fmt"
"net"
"github.com/jsimonetti/rtnetlink"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"golang.org/x/sys/unix"
)
type RtNetlinkConfig struct {
conn *rtnetlink.Conn
}
func NewRtNetlinkConfig() (*RtNetlinkConfig, error) {
conn, err := rtnetlink.Dial(nil)
if err != nil {
return nil, err
}
return &RtNetlinkConfig{conn: conn}, nil
}
const WIREGUARD_MTU = 1420
// Create a netlink interface if it does not exist. ifName is the name of the netlink interface
func (c *RtNetlinkConfig) CreateLink(ifName string) error {
_, err := net.InterfaceByName(ifName)
if err == nil {
return fmt.Errorf("interface %s already exists", ifName)
}
err = c.conn.Link.New(&rtnetlink.LinkMessage{
Family: unix.AF_UNSPEC,
Flags: unix.IFF_UP,
Attributes: &rtnetlink.LinkAttributes{
Name: ifName,
Info: &rtnetlink.LinkInfo{Kind: "wireguard"},
MTU: uint32(WIREGUARD_MTU),
},
})
if err != nil {
return fmt.Errorf("failed to create wireguard interface: %w", err)
}
return nil
}
// Delete link delete the specified interface
func (c *RtNetlinkConfig) DeleteLink(ifName string) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed to get interface %s %w", ifName, err)
}
err = c.conn.Link.Delete(uint32(iface.Index))
if err != nil {
return fmt.Errorf("failed to delete wg interface %w", err)
}
return nil
}
// AddAddress adds an address to the given interface.
func (c *RtNetlinkConfig) AddAddress(ifName string, address string) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed to get interface %s error: %w", ifName, err)
}
addr, cidr, err := net.ParseCIDR(address)
if err != nil {
return fmt.Errorf("failed to parse CIDR %s error: %w", addr, err)
}
family := unix.AF_INET6
ipv4 := cidr.IP.To4()
if ipv4 != nil {
family = unix.AF_INET
}
// Calculate the prefix length
ones, _ := cidr.Mask.Size()
// Calculate the broadcast IP
// Only used when family is AF_INET
var brd net.IP
if ipv4 != nil {
brd = make(net.IP, len(ipv4))
binary.BigEndian.PutUint32(brd, binary.BigEndian.Uint32(ipv4)|^binary.BigEndian.Uint32(net.IP(cidr.Mask).To4()))
}
err = c.conn.Address.New(&rtnetlink.AddressMessage{
Family: uint8(family),
PrefixLength: uint8(ones),
Scope: unix.RT_SCOPE_UNIVERSE,
Index: uint32(iface.Index),
Attributes: &rtnetlink.AddressAttributes{
Address: addr,
Local: addr,
Broadcast: brd,
},
})
if err != nil {
err = fmt.Errorf("failed to add address to link %w", err)
}
return err
}
// AddRoute: adds a route to the routing table.
// ifName is the intrface to add the route to
// gateway is the IP of the gateway device to hop to
// dst is the network prefix of the advertised destination
func (c *RtNetlinkConfig) AddRoute(ifName string, route Route) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed accessing interface %s error %w", ifName, err)
}
gw := route.Gateway
dst := route.Destination
var family uint8 = unix.AF_INET6
if dst.IP.To4() != nil {
family = unix.AF_INET
}
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Replace(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil {
return fmt.Errorf("failed to add route %w", err)
}
return nil
}
// DeleteRoute deletes routes with the gateway and destination
func (c *RtNetlinkConfig) DeleteRoute(ifName string, route Route) error {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return fmt.Errorf("failed accessing interface %s error %w", ifName, err)
}
gw := route.Gateway
dst := route.Destination
var family uint8 = unix.AF_INET6
if dst.IP.To4() != nil {
family = unix.AF_INET
}
attr := rtnetlink.RouteAttributes{
Dst: dst.IP,
OutIface: uint32(iface.Index),
Gateway: gw,
}
ones, _ := dst.Mask.Size()
err = c.conn.Route.Delete(&rtnetlink.RouteMessage{
Family: family,
Table: unix.RT_TABLE_MAIN,
Protocol: unix.RTPROT_BOOT,
Scope: unix.RT_SCOPE_LINK,
Type: unix.RTN_UNICAST,
DstLength: uint8(ones),
Attributes: attr,
})
if err != nil {
return fmt.Errorf("failed to delete route %s", dst.IP.String())
}
return nil
}
type Route struct {
Gateway net.IP
Destination net.IPNet
}
func (r1 Route) equal(r2 Route) bool {
return r1.Gateway.String() == r2.Gateway.String() &&
r1.Destination.String() == r2.Destination.String()
}
// DeleteRoutes deletes all routes not in exclude
func (c *RtNetlinkConfig) DeleteRoutes(ifName string, family uint8, exclude ...Route) error {
routes, err := c.listRoutes(ifName, family)
if err != nil {
return err
}
ifRoutes := make([]Route, 0)
for _, rtRoute := range routes {
maskSize := 128
if family == unix.AF_INET {
maskSize = 32
}
cidr := net.CIDRMask(int(rtRoute.DstLength), maskSize)
route := Route{
Gateway: rtRoute.Attributes.Gateway,
Destination: net.IPNet{IP: rtRoute.Attributes.Dst, Mask: cidr},
}
ifRoutes = append(ifRoutes, route)
}
shouldExclude := func(r Route) bool {
for _, route := range exclude {
if route.equal(r) {
return false
}
}
return true
}
toDelete := Filter(ifRoutes, shouldExclude)
for _, route := range toDelete {
logging.Log.WriteInfof("Deleting route: %s", route.Gateway.String())
err := c.DeleteRoute(ifName, route)
if err != nil {
return err
}
}
return nil
}
// listRoutes lists all routes on the interface
func (c *RtNetlinkConfig) listRoutes(ifName string, family uint8) ([]rtnetlink.RouteMessage, error) {
iface, err := net.InterfaceByName(ifName)
if err != nil {
return nil, fmt.Errorf("failed accessing interface %s error %w", ifName, err)
}
routes, err := c.conn.Route.List()
if err != nil {
return nil, fmt.Errorf("failed to get route %w", err)
}
filterFunc := func(r rtnetlink.RouteMessage) bool {
return r.Attributes.Gateway != nil && r.Attributes.OutIface == uint32(iface.Index)
}
routes = Filter(routes, filterFunc)
return routes, nil
}
func (c *RtNetlinkConfig) Close() error {
return c.conn.Close()
}

42
pkg/lib/timer.go Normal file
View File

@ -0,0 +1,42 @@
package lib
import "time"
type TimerFunc = func() error
type Timer struct {
f TimerFunc
quit chan struct{}
updateRate int
}
func (t *Timer) Run() error {
ticker := time.NewTicker(time.Duration(t.updateRate) * time.Second)
t.quit = make(chan struct{})
for {
select {
case <-ticker.C:
err := t.f()
if err != nil {
return err
}
case <-t.quit:
break
}
}
}
func (t *Timer) Stop() error {
close(t.quit)
return nil
}
func NewTimer(f TimerFunc, updateRate int) *Timer {
return &Timer{
f: f,
updateRate: updateRate,
}
}

View File

@ -2,6 +2,7 @@
package logging
import (
"io"
"os"
"github.com/sirupsen/logrus"
@ -15,6 +16,7 @@ type Logger interface {
WriteInfof(msg string, args ...interface{})
WriteErrorf(msg string, args ...interface{})
WriteWarnf(msg string, args ...interface{})
Writer() io.Writer
}
type LogrusLogger struct {
@ -33,6 +35,10 @@ func (l *LogrusLogger) WriteWarnf(msg string, args ...interface{}) {
l.logger.Warnf(msg, args...)
}
func (l *LogrusLogger) Writer() io.Writer {
return l.logger.Writer()
}
func NewLogrusLogger() *LogrusLogger {
logger := logrus.New()
logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})

46
pkg/mesh/alias.go Normal file
View File

@ -0,0 +1,46 @@
package mesh
import (
"fmt"
"github.com/tim-beatham/wgmesh/pkg/hosts"
)
type MeshAliasManager interface {
AddAliases(nodes []MeshNode)
RemoveAliases(node []MeshNode)
}
type AliasManager struct {
hosts hosts.HostsManipulator
}
// AddAliases: on node update or change add aliases to the hosts file
func (a *AliasManager) AddAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.AddAddr(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
// RemoveAliases: on node remove remove aliases from the hosts file
func (a *AliasManager) RemoveAliases(nodes []MeshNode) {
for _, node := range nodes {
if node.GetAlias() != "" {
a.hosts.Remove(hosts.HostsEntry{
Alias: fmt.Sprintf("%s.smeg", node.GetAlias()),
Ip: node.GetWgHost().IP,
})
}
}
}
func NewAliasManager() MeshAliasManager {
return &AliasManager{
hosts: hosts.NewHostsManipulator(),
}
}

View File

@ -1,10 +1,15 @@
package mesh
import (
"errors"
"fmt"
"net"
"slices"
"time"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/route"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@ -12,14 +17,29 @@ import (
type MeshConfigApplyer interface {
ApplyConfig() error
RemovePeers(meshId string) error
SetMeshManager(manager MeshManager)
}
// WgMeshConfigApplyer applies WireGuard configuration
type WgMeshConfigApplyer struct {
meshManager MeshManager
config *conf.WgMeshConfiguration
routeInstaller route.RouteInstaller
}
func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
type routeNode struct {
gateway string
route Route
}
func (r *routeNode) equals(route2 *routeNode) bool {
return r.gateway == route2.gateway && RouteEquals(r.route, route2.route)
}
func (m *WgMeshConfigApplyer) convertMeshNode(node MeshNode, device *wgtypes.Device,
peerToClients map[string][]net.IPNet,
routes map[string][]routeNode) (*wgtypes.PeerConfig, error) {
endpoint, err := net.ResolveUDPAddr("udp", node.GetWgEndpoint())
if err != nil {
@ -35,20 +55,109 @@ func convertMeshNode(node MeshNode) (*wgtypes.PeerConfig, error) {
allowedips := make([]net.IPNet, 1)
allowedips[0] = *node.GetWgHost()
clients, ok := peerToClients[node.GetWgHost().String()]
if ok {
allowedips = append(allowedips, clients...)
}
for _, route := range node.GetRoutes() {
_, ipnet, _ := net.ParseCIDR(route)
allowedips = append(allowedips, *ipnet)
bestRoutes := routes[route.GetDestination().String()]
if len(bestRoutes) == 1 {
allowedips = append(allowedips, *route.GetDestination())
} else if len(bestRoutes) > 1 {
keyFunc := func(mn MeshNode) int {
pubKey, _ := mn.GetPublicKey()
return lib.HashString(pubKey.String())
}
bucketFunc := func(rn routeNode) int {
return lib.HashString(rn.gateway)
}
// Else there is more than one candidate so consistently hash
pickedRoute := lib.ConsistentHash(bestRoutes, node, bucketFunc, keyFunc)
if pickedRoute.gateway == pubKey.String() {
allowedips = append(allowedips, *route.GetDestination())
}
}
}
keepAlive := time.Duration(m.config.KeepAliveWg) * time.Second
existing := slices.IndexFunc(device.Peers, func(p wgtypes.Peer) bool {
pubKey, _ := node.GetPublicKey()
return p.PublicKey.String() == pubKey.String()
})
if existing != -1 {
endpoint = device.Peers[existing].Endpoint
}
peerConfig := wgtypes.PeerConfig{
PublicKey: pubKey,
Endpoint: endpoint,
AllowedIPs: allowedips,
PersistentKeepaliveInterval: &keepAlive,
}
return &peerConfig, nil
}
// getRoutes: finds the routes with the least hop distance. If more than one route exists
// consistently hash to evenly spread the distribution of traffic
func (m *WgMeshConfigApplyer) getRoutes(meshProvider MeshProvider) map[string][]routeNode {
mesh, _ := meshProvider.GetMesh()
routes := make(map[string][]routeNode)
meshPrefixes := lib.Map(lib.MapValues(m.meshManager.GetMeshes()), func(mesh MeshProvider) *net.IPNet {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
return ipNet
})
for _, node := range mesh.GetNodes() {
pubKey, _ := node.GetPublicKey()
meshRoutes, _ := meshProvider.GetRoutes(pubKey.String())
for _, route := range meshRoutes {
if lib.Contains(meshPrefixes, func(prefix *net.IPNet) bool {
if prefix == nil || route == nil || route.GetDestination() == nil {
return false
}
return prefix.Contains(route.GetDestination().IP)
}) {
continue
}
destination := route.GetDestination().String()
otherRoute, ok := routes[destination]
rn := routeNode{
gateway: pubKey.String(),
route: route,
}
if !ok {
otherRoute = make([]routeNode, 1)
otherRoute[0] = rn
routes[destination] = otherRoute
} else if route.GetHopCount() < otherRoute[0].route.GetHopCount() {
otherRoute[0] = rn
} else if otherRoute[0].route.GetHopCount() == route.GetHopCount() {
routes[destination] = append(otherRoute, rn)
}
}
}
return routes
}
func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
snap, err := mesh.GetMesh()
@ -56,18 +165,67 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err
}
nodes := snap.GetNodes()
nodes := lib.MapValues(snap.GetNodes())
peerConfigs := make([]wgtypes.PeerConfig, len(nodes))
peers := lib.Filter(nodes, func(mn MeshNode) bool {
return mn.GetType() == conf.PEER_ROLE
})
var count int = 0
for _, n := range nodes {
peer, err := convertMeshNode(n)
self, err := m.meshManager.GetSelf(mesh.GetMeshId())
if err != nil {
return err
}
peerToClients := make(map[string][]net.IPNet)
routes := m.getRoutes(mesh)
installedRoutes := make([]lib.Route, 0)
for _, n := range nodes {
if NodeEquals(n, self) {
continue
}
if n.GetType() == conf.CLIENT_ROLE && len(peers) > 0 && self.GetType() == conf.CLIENT_ROLE {
hashFunc := func(mn MeshNode) int {
return lib.HashString(mn.GetWgHost().String())
}
peer := lib.ConsistentHash(peers, n, hashFunc, hashFunc)
clients, ok := peerToClients[peer.GetWgHost().String()]
if !ok {
clients = make([]net.IPNet, 0)
peerToClients[peer.GetWgHost().String()] = clients
}
peerToClients[peer.GetWgHost().String()] = append(clients, *n.GetWgHost())
continue
}
dev, _ := mesh.GetDevice()
peer, err := m.convertMeshNode(n, dev, peerToClients, routes)
if err != nil {
return err
}
for _, route := range peer.AllowedIPs {
ula := &ip.ULABuilder{}
ipNet, _ := ula.GetIPNet(mesh.GetMeshId())
if !ipNet.Contains(route.IP) {
installedRoutes = append(installedRoutes, lib.Route{
Gateway: n.GetWgHost().IP,
Destination: route,
})
}
}
peerConfigs[count] = *peer
count++
}
@ -82,6 +240,12 @@ func (m *WgMeshConfigApplyer) updateWgConf(mesh MeshProvider) error {
return err
}
err = m.routeInstaller.InstallRoutes(dev.Name, installedRoutes...)
if err != nil {
return err
}
return m.meshManager.GetClient().ConfigureDevice(dev.Name, cfg)
}
@ -101,7 +265,7 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
mesh := m.meshManager.GetMesh(meshId)
if mesh == nil {
return errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
return fmt.Errorf("mesh %s does not exist", meshId)
}
dev, err := mesh.GetDevice()
@ -112,12 +276,19 @@ func (m *WgMeshConfigApplyer) RemovePeers(meshId string) error {
m.meshManager.GetClient().ConfigureDevice(dev.Name, wgtypes.Config{
ReplacePeers: true,
Peers: make([]wgtypes.PeerConfig, 1),
Peers: make([]wgtypes.PeerConfig, 0),
})
return nil
}
func NewWgMeshConfigApplyer(manager MeshManager) MeshConfigApplyer {
return &WgMeshConfigApplyer{meshManager: manager}
func (m *WgMeshConfigApplyer) SetMeshManager(manager MeshManager) {
m.meshManager = manager
}
func NewWgMeshConfigApplyer(config *conf.WgMeshConfiguration) MeshConfigApplyer {
return &WgMeshConfigApplyer{
config: config,
routeInstaller: route.NewRouteInstaller(),
}
}

View File

@ -61,7 +61,7 @@ func (c *MeshDOTConverter) graphNode(g *graph.Graph, node MeshNode, meshId strin
self, _ := c.manager.GetSelf(meshId)
if node.GetHostEndpoint() == self.GetHostEndpoint() {
if NodeEquals(self, node) {
return
}

View File

@ -14,20 +14,27 @@ import (
)
type MeshManager interface {
CreateMesh(devName string, port int) (string, error)
CreateMesh(port int) (string, error)
AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider
EnableInterface(meshId string) error
GetPublicKey(meshId string) (*wgtypes.Key, error)
AddSelf(params *AddSelfParams) error
LeaveMesh(meshId string) error
GetSelf(meshId string) (MeshNode, error)
ApplyConfig() error
SetDescription(description string) error
SetAlias(alias string) error
SetService(service string, value string) error
RemoveService(service string) error
UpdateTimeStamp() error
GetClient() *wgctrl.Client
GetMeshes() map[string]MeshProvider
Prune() error
Close() error
GetMonitor() MeshMonitor
GetNode(string, string) MeshNode
GetRouteManager() RouteManager
}
type MeshManagerImpl struct {
@ -44,18 +51,94 @@ type MeshManagerImpl struct {
idGenerator lib.IdGenerator
ipAllocator ip.IPAllocator
interfaceManipulator wg.WgInterfaceManipulator
Monitor MeshMonitor
}
// GetRouteManager implements MeshManager.
func (m *MeshManagerImpl) GetRouteManager() RouteManager {
return m.RouteManager
}
// RemoveService implements MeshManager.
func (m *MeshManagerImpl) RemoveService(service string) error {
for _, mesh := range m.Meshes {
err := mesh.RemoveService(m.HostParameters.GetPublicKey(), service)
if err != nil {
return err
}
}
return nil
}
// SetService implements MeshManager.
func (m *MeshManagerImpl) SetService(service string, value string) error {
for _, mesh := range m.Meshes {
err := mesh.AddService(m.HostParameters.GetPublicKey(), service, value)
if err != nil {
return err
}
}
return nil
}
func (m *MeshManagerImpl) GetNode(meshid, nodeId string) MeshNode {
mesh, ok := m.Meshes[meshid]
if !ok {
return nil
}
node, err := mesh.GetNode(nodeId)
if err != nil {
return nil
}
return node
}
// GetMonitor implements MeshManager.
func (m *MeshManagerImpl) GetMonitor() MeshMonitor {
return m.Monitor
}
// Prune implements MeshManager.
func (m *MeshManagerImpl) Prune() error {
for _, mesh := range m.Meshes {
err := mesh.Prune(m.conf.PruneTime)
if err != nil {
return err
}
}
return nil
}
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
meshId, err := m.idGenerator.GetId()
var ifName string = ""
if err != nil {
return "", err
}
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(port, m.HostParameters.PrivateKey)
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: devName,
DevName: ifName,
Port: port,
Conf: m.conf,
Client: m.Client,
@ -63,39 +146,34 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
})
if err != nil {
return "", err
}
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName,
Port: port,
})
if err != nil {
return "", nil
return "", fmt.Errorf("error creating mesh: %w", err)
}
m.Meshes[meshId] = nodeManager
err = m.configApplyer.RemovePeers(meshId)
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
return meshId, nil
}
type AddMeshParams struct {
MeshId string
DevName string
WgPort int
MeshBytes []byte
}
// AddMesh: Add the mesh to the list of meshes
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string
var err error
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort, m.HostParameters.PrivateKey)
if err != nil {
return err
}
}
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: params.DevName,
DevName: ifName,
Port: params.WgPort,
Conf: m.conf,
Client: m.Client,
@ -113,11 +191,7 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
}
m.Meshes[params.MeshId] = meshProvider
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName,
Port: params.WgPort,
})
return nil
}
// HasChanges returns true if the mesh has changes
@ -131,37 +205,13 @@ func (m *MeshManagerImpl) GetMesh(meshId string) MeshProvider {
return theMesh
}
// EnableInterface: Enables the given WireGuard interface.
func (s *MeshManagerImpl) EnableInterface(meshId string) error {
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
meshNode, err := s.GetSelf(meshId)
if err != nil {
return err
}
mesh := s.GetMesh(meshId)
if err != nil {
return err
}
dev, err := mesh.GetDevice()
if err != nil {
return err
}
return s.interfaceManipulator.EnableInterface(dev.Name, meshNode.GetWgHost().String())
}
// GetPublicKey: Gets the public key of the WireGuard mesh
func (s *MeshManagerImpl) GetPublicKey(meshId string) (*wgtypes.Key, error) {
if s.conf.StubWg {
zeroedKey := make([]byte, wgtypes.KeyLen)
return (*wgtypes.Key)(zeroedKey), nil
}
mesh, ok := s.Meshes[meshId]
if !ok {
@ -183,64 +233,100 @@ type AddSelfParams struct {
// WgPort is the WireGuard port to advertise
WgPort int
// Endpoint is the alias of the machine to send routable packets
// to
Endpoint string
}
// AddSelf adds this host to the mesh
func (s *MeshManagerImpl) AddSelf(params *AddSelfParams) error {
pubKey, err := s.GetPublicKey(params.MeshId)
mesh := s.GetMesh(params.MeshId)
if mesh == nil {
return fmt.Errorf("addself: mesh %s does not exist", params.MeshId)
}
if params.WgPort == 0 && !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return err
}
nodeIP, err := s.ipAllocator.GetIP(*pubKey, params.MeshId)
params.WgPort = device.ListenPort
}
pubKey := s.HostParameters.PrivateKey.PublicKey()
nodeIP, err := s.ipAllocator.GetIP(pubKey, params.MeshId)
if err != nil {
return err
}
node := s.nodeFactory.Build(&MeshNodeFactoryParams{
PublicKey: pubKey,
PublicKey: &pubKey,
NodeIP: nodeIP,
WgPort: params.WgPort,
Endpoint: params.Endpoint,
})
if !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return fmt.Errorf("failed to get device %w", err)
}
err = s.interfaceManipulator.AddAddress(device.Name, fmt.Sprintf("%s/64", nodeIP))
if err != nil {
return fmt.Errorf("addSelf: failed to add address to dev %w", err)
}
}
s.Meshes[params.MeshId].AddNode(node)
return s.RouteManager.UpdateRoutes()
}
// LeaveMesh leaves the mesh network
func (s *MeshManagerImpl) LeaveMesh(meshId string) error {
_, exists := s.Meshes[meshId]
mesh, exists := s.Meshes[meshId]
if !exists {
return errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
return fmt.Errorf("mesh %s does not exist", meshId)
}
// For now just delete the mesh with the ID.
var err error
if !s.conf.StubWg {
device, err := mesh.GetDevice()
if err != nil {
return err
}
err = s.interfaceManipulator.RemoveInterface(device.Name)
if err != nil {
return err
}
}
err = s.RouteManager.RemoveRoutes(meshId)
delete(s.Meshes, meshId)
return nil
return err
}
func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
meshInstance, ok := s.Meshes[meshId]
if !ok {
return nil, errors.New(fmt.Sprintf("mesh %s does not exist", meshId))
return nil, fmt.Errorf("mesh %s does not exist", meshId)
}
snapshot, err := meshInstance.GetMesh()
logging.Log.WriteInfof(s.HostParameters.GetPublicKey())
node, err := meshInstance.GetNode(s.HostParameters.GetPublicKey())
if err != nil {
return nil, err
}
node, ok := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if !ok {
return nil, errors.New("the node doesn't exist in the mesh")
}
@ -248,34 +334,52 @@ func (s *MeshManagerImpl) GetSelf(meshId string) (MeshNode, error) {
}
func (s *MeshManagerImpl) ApplyConfig() error {
return s.configApplyer.ApplyConfig()
if s.conf.StubWg {
return nil
}
err := s.configApplyer.ApplyConfig()
if err != nil {
return err
}
return nil
}
func (s *MeshManagerImpl) SetDescription(description string) error {
for _, mesh := range s.Meshes {
err := mesh.SetDescription(s.HostParameters.HostEndpoint, description)
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetDescription(s.HostParameters.GetPublicKey(), description)
if err != nil {
return err
}
}
}
return nil
}
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error {
// SetAlias implements MeshManager.
func (s *MeshManagerImpl) SetAlias(alias string) error {
for _, mesh := range s.Meshes {
snapshot, err := mesh.GetMesh()
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.SetAlias(s.HostParameters.GetPublicKey(), alias)
if err != nil {
return err
}
}
}
return nil
}
_, exists := snapshot.GetNodes()[s.HostParameters.HostEndpoint]
if exists {
err = mesh.UpdateTimeStamp(s.HostParameters.HostEndpoint)
// UpdateTimeStamp updates the timestamp of this node in all meshes
func (s *MeshManagerImpl) UpdateTimeStamp() error {
for _, mesh := range s.Meshes {
if mesh.NodeExists(s.HostParameters.GetPublicKey()) {
err := mesh.UpdateTimeStamp(s.HostParameters.GetPublicKey())
if err != nil {
return err
@ -294,6 +398,29 @@ func (s *MeshManagerImpl) GetMeshes() map[string]MeshProvider {
return s.Meshes
}
// Close the mesh manager
func (s *MeshManagerImpl) Close() error {
if s.conf.StubWg {
return nil
}
for _, mesh := range s.Meshes {
dev, err := mesh.GetDevice()
if err != nil {
return err
}
err = s.interfaceManipulator.RemoveInterface(dev.Name)
if err != nil {
return err
}
}
return nil
}
// NewMeshManagerParams params required to create an instance of a mesh manager
type NewMeshManagerParams struct {
Conf conf.WgMeshConfiguration
@ -304,21 +431,16 @@ type NewMeshManagerParams struct {
IPAllocator ip.IPAllocator
InterfaceManipulator wg.WgInterfaceManipulator
ConfigApplyer MeshConfigApplyer
RouteManager RouteManager
}
// Creates a new instance of a mesh manager with the given parameters
func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl {
hostParams := HostParameters{}
switch params.Conf.Endpoint {
case "":
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", lib.GetOutboundIP().String(), params.Conf.GrpcPort)
default:
hostParams.HostEndpoint = fmt.Sprintf("%s:%s", params.Conf.Endpoint, params.Conf.GrpcPort)
func NewMeshManager(params *NewMeshManagerParams) MeshManager {
privateKey, _ := wgtypes.GeneratePrivateKey()
hostParams := HostParameters{
PrivateKey: &privateKey,
}
logging.Log.WriteInfof("Endpoint %s", hostParams.HostEndpoint)
m := &MeshManagerImpl{
Meshes: make(map[string]MeshProvider),
HostParameters: &hostParams,
@ -327,10 +449,22 @@ func NewMeshManager(params *NewMeshManagerParams) *MeshManagerImpl {
Client: params.Client,
conf: &params.Conf,
}
m.configApplyer = params.ConfigApplyer
m.RouteManager = params.RouteManager
if m.RouteManager == nil {
m.RouteManager = NewRouteManager(m)
}
m.idGenerator = params.IdGenerator
m.ipAllocator = params.IPAllocator
m.interfaceManipulator = params.InterfaceManipulator
m.Monitor = NewMeshMonitor(m)
aliasManager := NewAliasManager()
m.Monitor.AddUpdateCallback(aliasManager.AddAliases)
m.Monitor.AddRemoveCallback(aliasManager.RemoveAliases)
return m
}

View File

@ -18,11 +18,11 @@ func getMeshConfiguration() *conf.WgMeshConfiguration {
BranchRate: 3,
InterClusterChance: 0.15,
InfectionCount: 2,
KeepAliveRate: 60,
KeepAliveTime: 60,
}
}
func getMeshManager() *MeshManagerImpl {
func getMeshManager() MeshManager {
manager := NewMeshManager(&NewMeshManagerParams{
Conf: *getMeshConfiguration(),
Client: nil,
@ -32,6 +32,7 @@ func getMeshManager() *MeshManagerImpl {
IPAllocator: &ip.ULABuilder{},
InterfaceManipulator: &wg.WgInterfaceManipulatorStub{},
ConfigApplyer: &MeshConfigApplyerStub{},
RouteManager: &RouteManagerStub{},
})
return manager
@ -50,7 +51,7 @@ func TestCreateMeshCreatesANewMeshProvider(t *testing.T) {
t.Fatal(`meshId should not be empty`)
}
_, exists := manager.Meshes[meshId]
_, exists := manager.GetMeshes()[meshId]
if !exists {
t.Fatal(`mesh was not created when it should be`)
@ -63,7 +64,6 @@ func TestAddMeshAddsAMesh(t *testing.T) {
manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
@ -82,7 +82,6 @@ func TestAddMeshMeshAlreadyExistsReplacesIt(t *testing.T) {
for i := 0; i < 2; i++ {
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
@ -105,7 +104,6 @@ func TestAddSelfAddsSelfToTheMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
@ -174,7 +172,6 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
err := manager.AddMesh(&AddMeshParams{
MeshId: meshId,
DevName: "wg0",
WgPort: 6000,
MeshBytes: make([]byte, 0),
})
@ -186,10 +183,10 @@ func TestLeaveMeshDeletesMesh(t *testing.T) {
err = manager.LeaveMesh(meshId)
if err != nil {
t.Error(err)
t.Fatalf("%s", err.Error())
}
_, exists := manager.Meshes[meshId]
_, exists := manager.GetMeshes()[meshId]
if exists {
t.Fatalf(`expected mesh to have been deleted`)
@ -200,8 +197,8 @@ func TestSetDescription(t *testing.T) {
manager := getMeshManager()
description := "wooooo"
meshId1, _ := manager.CreateMesh("wg0", 5000)
meshId2, _ := manager.CreateMesh("wg0", 5001)
meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,
@ -224,8 +221,8 @@ func TestSetDescription(t *testing.T) {
func TestUpdateTimeStampUpdatesAllMeshes(t *testing.T) {
manager := getMeshManager()
meshId1, _ := manager.CreateMesh("wg0", 5000)
meshId2, _ := manager.CreateMesh("wg0", 5001)
meshId1, _ := manager.CreateMesh(5000)
meshId2, _ := manager.CreateMesh(5001)
manager.AddSelf(&AddSelfParams{
MeshId: meshId1,

81
pkg/mesh/monitor.go Normal file
View File

@ -0,0 +1,81 @@
package mesh
type OnChange = func([]MeshNode)
type MeshMonitor interface {
AddUpdateCallback(cb OnChange)
AddRemoveCallback(cb OnChange)
Trigger() error
}
type MeshMonitorImpl struct {
updateCbs []OnChange
removeCbs []OnChange
nodes map[string]MeshNode
manager MeshManager
}
// Trigger causes the mesh monitor to trigger all of
// the callbacks.
func (m *MeshMonitorImpl) Trigger() error {
changedNodes := make([]MeshNode, 0)
removedNodes := make([]MeshNode, 0)
nodes := make(map[string]MeshNode)
for _, mesh := range m.manager.GetMeshes() {
snapshot, err := mesh.GetMesh()
if err != nil {
return err
}
for _, node := range snapshot.GetNodes() {
previous, exists := m.nodes[node.GetWgHost().String()]
if !exists || !NodeEquals(previous, node) {
changedNodes = append(changedNodes, node)
}
nodes[node.GetWgHost().String()] = node
}
}
for _, previous := range m.nodes {
_, ok := nodes[previous.GetWgHost().String()]
if !ok {
removedNodes = append(removedNodes, previous)
}
}
if len(removedNodes) > 0 {
for _, cb := range m.removeCbs {
cb(removedNodes)
}
}
if len(changedNodes) > 0 {
for _, cb := range m.updateCbs {
cb(changedNodes)
}
}
return nil
}
func (m *MeshMonitorImpl) AddUpdateCallback(cb OnChange) {
m.updateCbs = append(m.updateCbs, cb)
}
func (m *MeshMonitorImpl) AddRemoveCallback(cb OnChange) {
m.removeCbs = append(m.removeCbs, cb)
}
func NewMeshMonitor(manager MeshManager) MeshMonitor {
return &MeshMonitorImpl{
updateCbs: make([]OnChange, 0),
nodes: make(map[string]MeshNode),
manager: manager,
}
}

16
pkg/mesh/pruner.go Normal file
View File

@ -0,0 +1,16 @@
package mesh
import (
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
func pruneFunction(m MeshManager) lib.TimerFunc {
return func() error {
return m.Prune()
}
}
func NewPruner(m MeshManager, conf conf.WgMeshConfiguration) *lib.Timer {
return lib.NewTimer(pruneFunction(m), conf.PruneTime/2)
}

View File

@ -2,17 +2,17 @@ package mesh
import (
"github.com/tim-beatham/wgmesh/pkg/ip"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/route"
)
type RouteManager interface {
UpdateRoutes() error
RemoveRoutes(meshId string) error
}
type RouteManagerImpl struct {
meshManager MeshManager
routeInstaller route.RouteInstaller
}
func (r *RouteManagerImpl) UpdateRoutes() error {
@ -20,6 +20,24 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
ulaBuilder := new(ip.ULABuilder)
for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
if err != nil {
return err
}
pubKey, err := self.GetPublicKey()
if err != nil {
return err
}
routes, err := mesh1.GetRoutes(pubKey.String())
if err != nil {
return err
}
for _, mesh2 := range meshes {
if mesh1 == mesh2 {
continue
@ -32,13 +50,11 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return err
}
self, err := r.meshManager.GetSelf(mesh1.GetMeshId())
if err != nil {
return err
}
err = mesh1.AddRoutes(self.GetHostEndpoint(), ipNet.String())
err = mesh2.AddRoutes(NodeID(self), append(lib.MapValues(routes), &RouteStub{
Destination: ipNet,
HopCount: 0,
Path: make([]string, 0),
})...)
if err != nil {
return err
@ -49,6 +65,29 @@ func (r *RouteManagerImpl) UpdateRoutes() error {
return nil
}
func NewRouteManager(m MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m, routeInstaller: route.NewRouteInstaller()}
// removeRoutes: removes all meshes we are no longer a part of
func (r *RouteManagerImpl) RemoveRoutes(meshId string) error {
ulaBuilder := new(ip.ULABuilder)
meshes := r.meshManager.GetMeshes()
ipNet, err := ulaBuilder.GetIPNet(meshId)
if err != nil {
return err
}
for _, mesh1 := range meshes {
self, err := r.meshManager.GetSelf(meshId)
if err != nil {
return err
}
mesh1.RemoveRoutes(NodeID(self), ipNet.String())
}
return nil
}
func NewRouteManager(m MeshManager) RouteManager {
return &RouteManagerImpl{meshManager: m}
}

16
pkg/mesh/route_stub.go Normal file
View File

@ -0,0 +1,16 @@
package mesh
type RouteManagerStub struct {
}
func (r *RouteManagerStub) UpdateRoutes() error {
return nil
}
func (r *RouteManagerStub) InstallRoutes() error {
return nil
}
func (r *RouteManagerStub) RemoveRoutes(meshId string) error {
return nil
}

View File

@ -16,11 +16,26 @@ type MeshNodeStub struct {
wgEndpoint string
wgHost *net.IPNet
timeStamp int64
routes []string
routes []Route
identifier string
description string
}
// GetType implements MeshNode.
func (*MeshNodeStub) GetType() conf.NodeType {
return conf.PEER_ROLE
}
// GetServices implements MeshNode.
func (*MeshNodeStub) GetServices() map[string]string {
return make(map[string]string)
}
// GetAlias implements MeshNode.
func (*MeshNodeStub) GetAlias() string {
return ""
}
func (m *MeshNodeStub) GetHostEndpoint() string {
return m.hostEndpoint
}
@ -41,7 +56,7 @@ func (m *MeshNodeStub) GetTimeStamp() int64 {
return m.timeStamp
}
func (m *MeshNodeStub) GetRoutes() []string {
func (m *MeshNodeStub) GetRoutes() []Route {
return m.routes
}
@ -66,6 +81,50 @@ type MeshProviderStub struct {
snapshot *MeshSnapshotStub
}
func (*MeshProviderStub) GetRoutes(targetId string) (map[string]Route, error) {
return nil, nil
}
// GetNodeIds implements MeshProvider.
func (*MeshProviderStub) GetPeers() []string {
return make([]string, 0)
}
// GetNode implements MeshProvider.
func (*MeshProviderStub) GetNode(string) (MeshNode, error) {
return nil, nil
}
// NodeExists implements MeshProvider.
func (*MeshProviderStub) NodeExists(string) bool {
return false
}
// AddService implements MeshProvider.
func (*MeshProviderStub) AddService(nodeId string, key string, value string) error {
return nil
}
// RemoveService implements MeshProvider.
func (*MeshProviderStub) RemoveService(nodeId string, key string) error {
return nil
}
// SetAlias implements MeshProvider.
func (*MeshProviderStub) SetAlias(nodeId string, alias string) error {
panic("unimplemented")
}
// RemoveRoutes implements MeshProvider.
func (*MeshProviderStub) RemoveRoutes(nodeId string, route ...string) error {
return nil
}
// Prune implements MeshProvider.
func (*MeshProviderStub) Prune(pruneAmount int) error {
return nil
}
// UpdateTimeStamp implements MeshProvider.
func (*MeshProviderStub) UpdateTimeStamp(nodeId string) error {
return nil
@ -104,7 +163,7 @@ func (s *MeshProviderStub) HasChanges() bool {
return false
}
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...string) error {
func (s *MeshProviderStub) AddRoutes(nodeId string, route ...Route) error {
return nil
}
@ -138,7 +197,7 @@ func (s *StubNodeFactory) Build(params *MeshNodeFactoryParams) MeshNode {
wgEndpoint: fmt.Sprintf("%s:%s", params.Endpoint, s.Config.GrpcPort),
wgHost: wgHost,
timeStamp: time.Now().Unix(),
routes: make([]string, 0),
routes: make([]Route, 0),
identifier: "abc",
description: "A Mesh Node Stub",
}
@ -154,15 +213,58 @@ func (a *MeshConfigApplyerStub) RemovePeers(meshId string) error {
return nil
}
func (a *MeshConfigApplyerStub) SetMeshManager(manager MeshManager) {
}
type MeshManagerStub struct {
meshes map[string]MeshProvider
}
// GetRouteManager implements MeshManager.
func (*MeshManagerStub) GetRouteManager() RouteManager {
panic("unimplemented")
}
// GetNode implements MeshManager.
func (*MeshManagerStub) GetNode(string, string) MeshNode {
panic("unimplemented")
}
// RemoveService implements MeshManager.
func (*MeshManagerStub) RemoveService(service string) error {
panic("unimplemented")
}
// SetService implements MeshManager.
func (*MeshManagerStub) SetService(service string, value string) error {
panic("unimplemented")
}
// GetMonitor implements MeshManager.
func (*MeshManagerStub) GetMonitor() MeshMonitor {
panic("unimplemented")
}
// SetAlias implements MeshManager.
func (*MeshManagerStub) SetAlias(alias string) error {
panic("unimplemented")
}
// Close implements MeshManager.
func (*MeshManagerStub) Close() error {
panic("unimplemented")
}
// Prune implements MeshManager.
func (*MeshManagerStub) Prune() error {
return nil
}
func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
}
func (m *MeshManagerStub) CreateMesh(devName string, port int) (string, error) {
func (m *MeshManagerStub) CreateMesh(port int) (string, error) {
return "tim123", nil
}
@ -185,10 +287,6 @@ func (m *MeshManagerStub) GetMesh(meshId string) MeshProvider {
snapshot: &MeshSnapshotStub{nodes: make(map[string]MeshNode)}}
}
func (m *MeshManagerStub) EnableInterface(meshId string) error {
return nil
}
func (m *MeshManagerStub) GetPublicKey(meshId string) (*wgtypes.Key, error) {
key, _ := wgtypes.GenerateKey()
return &key, nil

View File

@ -10,6 +10,33 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
type Route interface {
// GetDestination: returns the destination of the route
GetDestination() *net.IPNet
// GetHopCount: get the total hopcount of the prefix
GetHopCount() int
// GetPath: get a list of AS paths to get to the destination
GetPath() []string
}
type RouteStub struct {
Destination *net.IPNet
HopCount int
Path []string
}
func (r *RouteStub) GetDestination() *net.IPNet {
return r.Destination
}
func (r *RouteStub) GetHopCount() int {
return r.HopCount
}
func (r *RouteStub) GetPath() []string {
return r.Path
}
// MeshNode represents an implementation of a node in a mesh
type MeshNode interface {
// GetHostEndpoint: gets the gRPC endpoint of the node
@ -23,11 +50,35 @@ type MeshNode interface {
// GetTimestamp: get the UNIX time stamp of the ndoe
GetTimeStamp() int64
// GetRoutes: returns the routes that the nodes provides
GetRoutes() []string
GetRoutes() []Route
// GetIdentifier: returns the identifier of the node
GetIdentifier() string
// GetDescription: returns the description for this node
GetDescription() string
// GetAlias: associates the node with an alias. Potentially used
// for DNS and so forth.
GetAlias() string
// GetServices: returns a list of services offered by the node
GetServices() map[string]string
GetType() conf.NodeType
}
// NodeEquals: determines if two mesh nodes are equivalent to one another
func NodeEquals(node1, node2 MeshNode) bool {
key1, _ := node1.GetPublicKey()
key2, _ := node2.GetPublicKey()
return key1.String() == key2.String()
}
func RouteEquals(route1, route2 Route) bool {
return route1.GetDestination().String() == route2.GetDestination().String() &&
route1.GetHopCount() == route2.GetHopCount()
}
func NodeID(node MeshNode) string {
key, _ := node.GetPublicKey()
return key.String()
}
type MeshSnapshot interface {
@ -46,7 +97,7 @@ type MeshSyncer interface {
type MeshProvider interface {
// AddNode() adds a node to the mesh
AddNode(node MeshNode)
// GetMesh() returns a snapshot of the mesh provided by the mesh provider
// GetMesh() returns a snapshot of the mesh provided by the mesh provider.
GetMesh() (MeshSnapshot, error)
// GetMeshId() returns the ID of the mesh network
GetMeshId() string
@ -58,20 +109,45 @@ type MeshProvider interface {
GetDevice() (*wgtypes.Device, error)
// HasChanges returns true if we have changes since last time we synced
HasChanges() bool
// Record that we have changges and save the corresponding changes
// Record that we have changes and save the corresponding changes
SaveChanges()
// UpdateTimeStamp: update the timestamp of the given node
UpdateTimeStamp(nodeId string) error
// AddRoutes: adds routes to the given node
AddRoutes(nodeId string, route ...string) error
AddRoutes(nodeId string, route ...Route) error
// DeleteRoutes: deletes the routes from the node
RemoveRoutes(nodeId string, route ...string) error
// GetSyncer: returns the automerge syncer for sync
GetSyncer() MeshSyncer
// GetNode get a particular not within the mesh
GetNode(string) (MeshNode, error)
// NodeExists: returns true if a particular node exists false otherwise
NodeExists(string) bool
// SetDescription: sets the description of this automerge data type
SetDescription(nodeId string, description string) error
// SetAlias: set the alias of the nodeId
SetAlias(nodeId string, alias string) error
// AddService: adds the service to the given node
AddService(nodeId, key, value string) error
// RemoveService: removes the service form the node. throws an error if the service does not exist
RemoveService(nodeId, key string) error
// Prune: prunes all nodes that have not updated their timestamp in
// pruneAmount seconds
Prune(pruneAmount int) error
// GetPeers: get a list of contactable peers
GetPeers() []string
// GetRoutes(): Get all unique routes. Where the route with the least hop count is chosen
GetRoutes(targetNode string) (map[string]Route, error)
}
// HostParameters contains the IDs of a node
type HostParameters struct {
HostEndpoint string
// TODO: Contain the WireGungracefullyuard identifier in this
PrivateKey *wgtypes.Key
}
// GetPublicKey: gets the public key of the node
func (h *HostParameters) GetPublicKey() string {
return h.PrivateKey.PublicKey().String()
}
// MeshProviderFactoryParams parameters required to build a mesh provider

View File

@ -3,8 +3,10 @@ package query
import (
"encoding/json"
"fmt"
"strings"
"github.com/jmespath/go-jmespath"
"github.com/tim-beatham/wgmesh/pkg/conf"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
@ -23,14 +25,23 @@ type QueryError struct {
msg string
}
type QueryRoute struct {
Destination string `json:"destination"`
HopCount int `json:"hopCount"`
Path string `json:"path"`
}
type QueryNode struct {
HostEndpoint string `json:"hostEndpoint"`
PublicKey string `json:"publicKey"`
WgEndpoint string `json:"wgEndpoint"`
WgHost string `json:"wgHost"`
Timestamp int64 `json:"timestmap"`
Timestamp int64 `json:"timestamp"`
Description string `json:"description"`
Routes []string `json:"routes"`
Routes []QueryRoute `json:"routes"`
Alias string `json:"alias"`
Services map[string]string `json:"services"`
Type conf.NodeType `json:"type"`
}
func (m *QueryError) Error() string {
@ -51,7 +62,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return nil, err
}
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), meshNodeToQueryNode)
nodes := lib.Map(lib.MapValues(snapshot.GetNodes()), MeshNodeToQueryNode)
result, err := jmespath.Search(queryParams, nodes)
@ -63,7 +74,7 @@ func (j *JmesQuerier) Query(meshId, queryParams string) ([]byte, error) {
return bytes, err
}
func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
func MeshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode := new(QueryNode)
queryNode.HostEndpoint = node.GetHostEndpoint()
pubKey, _ := node.GetPublicKey()
@ -74,8 +85,18 @@ func meshNodeToQueryNode(node mesh.MeshNode) *QueryNode {
queryNode.WgHost = node.GetWgHost().String()
queryNode.Timestamp = node.GetTimeStamp()
queryNode.Routes = node.GetRoutes()
queryNode.Routes = lib.Map(node.GetRoutes(), func(r mesh.Route) QueryRoute {
return QueryRoute{
Destination: r.GetDestination().String(),
HopCount: r.GetHopCount(),
Path: strings.Join(r.GetPath(), ","),
}
})
queryNode.Description = node.GetDescription()
queryNode.Alias = node.GetAlias()
queryNode.Services = node.GetServices()
queryNode.Type = node.GetType()
return queryNode
}

View File

@ -2,6 +2,7 @@ package robin
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
@ -9,7 +10,9 @@ import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/ipc"
"github.com/tim-beatham/wgmesh/pkg/lib"
"github.com/tim-beatham/wgmesh/pkg/mesh"
"github.com/tim-beatham/wgmesh/pkg/query"
"github.com/tim-beatham/wgmesh/pkg/rpc"
)
@ -18,7 +21,7 @@ type IpcHandler struct {
}
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.GetMeshManager().CreateMesh(args.IfName, args.WgPort)
meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort)
if err != nil {
return err
@ -30,6 +33,10 @@ func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
Endpoint: args.Endpoint,
})
if err != nil {
return err
}
*reply = meshId
return err
}
@ -66,7 +73,9 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
configuration := n.Server.GetConfiguration()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(configuration.Timeout))
defer cancel()
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
@ -77,7 +86,6 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId,
DevName: args.IfName,
WgPort: args.Port,
MeshBytes: meshReply.Mesh,
})
@ -112,16 +120,22 @@ func (n *IpcHandler) LeaveMesh(meshId string, reply *string) error {
}
func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
mesh := n.Server.GetMeshManager().GetMesh(meshId)
meshSnapshot, err := mesh.GetMesh()
theMesh := n.Server.GetMeshManager().GetMesh(meshId)
if theMesh == nil {
return fmt.Errorf("mesh %s does not exist", meshId)
}
meshSnapshot, err := theMesh.GetMesh()
if err != nil {
return err
}
if mesh == nil {
if theMesh == nil {
return errors.New("mesh does not exist")
}
nodes := make([]ctrlserver.MeshNode, len(meshSnapshot.GetNodes()))
i := 0
@ -138,7 +152,15 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
PublicKey: pubKey.String(),
WgHost: node.GetWgHost().String(),
Timestamp: node.GetTimeStamp(),
Routes: node.GetRoutes(),
Routes: lib.Map(node.GetRoutes(), func(r mesh.Route) ctrlserver.MeshRoute {
return ctrlserver.MeshRoute{
Destination: r.GetDestination().String(),
Path: r.GetPath(),
}
}),
Description: node.GetDescription(),
Alias: node.GetAlias(),
Services: node.GetServices(),
}
nodes[i] = node
@ -149,18 +171,6 @@ func (n *IpcHandler) GetMesh(meshId string, reply *ipc.GetMeshReply) error {
return nil
}
func (n *IpcHandler) EnableInterface(meshId string, reply *string) error {
err := n.Server.GetMeshManager().EnableInterface(meshId)
if err != nil {
*reply = err.Error()
return err
}
*reply = "up"
return nil
}
func (n *IpcHandler) GetDOT(meshId string, reply *string) error {
g := mesh.NewMeshDotConverter(n.Server.GetMeshManager())
@ -196,6 +206,60 @@ func (n *IpcHandler) PutDescription(description string, reply *string) error {
return nil
}
func (n *IpcHandler) PutAlias(alias string, reply *string) error {
err := n.Server.GetMeshManager().SetAlias(alias)
if err != nil {
return err
}
*reply = fmt.Sprintf("Set alias to %s", alias)
return nil
}
func (n *IpcHandler) PutService(service ipc.PutServiceArgs, reply *string) error {
err := n.Server.GetMeshManager().SetService(service.Service, service.Value)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) DeleteService(service string, reply *string) error {
err := n.Server.GetMeshManager().RemoveService(service)
if err != nil {
return err
}
*reply = "success"
return nil
}
func (n *IpcHandler) GetNode(args ipc.GetNodeArgs, reply *string) error {
node := n.Server.GetMeshManager().GetNode(args.MeshId, args.NodeId)
if node == nil {
*reply = "nil"
return nil
}
queryNode := query.MeshNodeToQueryNode(node)
bytes, err := json.Marshal(queryNode)
if err != nil {
*reply = err.Error()
return nil
}
*reply = string(bytes)
return nil
}
type RobinIpcParams struct {
CtrlServer ctrlserver.CtrlServer
}

View File

@ -28,7 +28,3 @@ func (m *WgRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*rpc.
return &reply, nil
}
func (m *WgRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
return &rpc.JoinMeshReply{Success: true}, nil
}

View File

@ -1,22 +1,32 @@
package route
import (
"net"
"os/exec"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/lib"
"golang.org/x/sys/unix"
)
type RouteInstaller interface {
InstallRoutes(devName string, routes ...*net.IPNet) error
InstallRoutes(devName string, routes ...lib.Route) error
}
type RouteInstallerImpl struct{}
// InstallRoutes: installs a route into the routing table
func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet) error {
func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...lib.Route) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
}
err = rtnl.DeleteRoutes(devName, unix.AF_INET6, routes...)
if err != nil {
return err
}
for _, route := range routes {
err := r.installRoute(devName, route)
err := rtnl.AddRoute(devName, route)
if err != nil {
return err
@ -26,22 +36,6 @@ func (r *RouteInstallerImpl) InstallRoutes(devName string, routes ...*net.IPNet)
return nil
}
// installRoute: installs a route into the linux table
func (r *RouteInstallerImpl) installRoute(devName string, route *net.IPNet) error {
// TODO: Find a library that automates this
cmd := exec.Command("/usr/bin/ip", "-6", "route", "add", route.String(), "dev", devName)
logging.Log.WriteInfof("%s %s", route.String(), devName)
if msg, err := cmd.CombinedOutput(); err != nil {
logging.Log.WriteErrorf(err.Error())
logging.Log.WriteErrorf(string(msg))
return err
}
return nil
}
func NewRouteInstaller() RouteInstaller {
return &RouteInstallerImpl{}
}

View File

@ -20,77 +20,6 @@ const (
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type MeshNode struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
PublicKey string `protobuf:"bytes,1,opt,name=publicKey,proto3" json:"publicKey,omitempty"`
WgEndpoint string `protobuf:"bytes,2,opt,name=wgEndpoint,proto3" json:"wgEndpoint,omitempty"`
Endpoint string `protobuf:"bytes,3,opt,name=endpoint,proto3" json:"endpoint,omitempty"`
WgHost string `protobuf:"bytes,4,opt,name=wgHost,proto3" json:"wgHost,omitempty"`
}
func (x *MeshNode) Reset() {
*x = MeshNode{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *MeshNode) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*MeshNode) ProtoMessage() {}
func (x *MeshNode) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use MeshNode.ProtoReflect.Descriptor instead.
func (*MeshNode) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *MeshNode) GetPublicKey() string {
if x != nil {
return x.PublicKey
}
return ""
}
func (x *MeshNode) GetWgEndpoint() string {
if x != nil {
return x.WgEndpoint
}
return ""
}
func (x *MeshNode) GetEndpoint() string {
if x != nil {
return x.Endpoint
}
return ""
}
func (x *MeshNode) GetWgHost() string {
if x != nil {
return x.WgHost
}
return ""
}
type GetMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
@ -102,7 +31,7 @@ type GetMeshRequest struct {
func (x *GetMeshRequest) Reset() {
*x = GetMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -115,7 +44,7 @@ func (x *GetMeshRequest) String() string {
func (*GetMeshRequest) ProtoMessage() {}
func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -128,7 +57,7 @@ func (x *GetMeshRequest) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshRequest.ProtoReflect.Descriptor instead.
func (*GetMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1}
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{0}
}
func (x *GetMeshRequest) GetMeshId() string {
@ -149,7 +78,7 @@ type GetMeshReply struct {
func (x *GetMeshReply) Reset() {
*x = GetMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@ -162,7 +91,7 @@ func (x *GetMeshReply) String() string {
func (*GetMeshReply) ProtoMessage() {}
func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2]
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@ -175,7 +104,7 @@ func (x *GetMeshReply) ProtoReflect() protoreflect.Message {
// Deprecated: Use GetMeshReply.ProtoReflect.Descriptor instead.
func (*GetMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{2}
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{1}
}
func (x *GetMeshReply) GetMesh() []byte {
@ -185,145 +114,24 @@ func (x *GetMeshReply) GetMesh() []byte {
return nil
}
type JoinMeshRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Changes []byte `protobuf:"bytes,1,opt,name=changes,proto3" json:"changes,omitempty"`
MeshId string `protobuf:"bytes,2,opt,name=meshId,proto3" json:"meshId,omitempty"`
}
func (x *JoinMeshRequest) Reset() {
*x = JoinMeshRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshRequest) ProtoMessage() {}
func (x *JoinMeshRequest) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshRequest.ProtoReflect.Descriptor instead.
func (*JoinMeshRequest) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{3}
}
func (x *JoinMeshRequest) GetChanges() []byte {
if x != nil {
return x.Changes
}
return nil
}
func (x *JoinMeshRequest) GetMeshId() string {
if x != nil {
return x.MeshId
}
return ""
}
type JoinMeshReply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
}
func (x *JoinMeshReply) Reset() {
*x = JoinMeshReply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *JoinMeshReply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*JoinMeshReply) ProtoMessage() {}
func (x *JoinMeshReply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use JoinMeshReply.ProtoReflect.Descriptor instead.
func (*JoinMeshReply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP(), []int{4}
}
func (x *JoinMeshReply) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
var File_pkg_grpc_ctrlserver_ctrlserver_proto protoreflect.FileDescriptor
var file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc = []byte{
0x0a, 0x24, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x22, 0x7c, 0x0a, 0x08, 0x4d, 0x65, 0x73, 0x68, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x1c, 0x0a, 0x09,
0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52,
0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x77, 0x67,
0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a,
0x77, 0x67, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x65, 0x6e,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x65, 0x6e,
0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74,
0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x77, 0x67, 0x48, 0x6f, 0x73, 0x74, 0x22, 0x28,
0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74,
0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09,
0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65, 0x73, 0x68,
0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x22, 0x43, 0x0a, 0x0f,
0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12,
0x18, 0x0a, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c,
0x52, 0x07, 0x63, 0x68, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73,
0x68, 0x49, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49,
0x64, 0x22, 0x29, 0x0a, 0x0d, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20,
0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x32, 0x91, 0x01, 0x0a,
0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12,
0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71,
0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e,
0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x40,
0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x19, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00,
0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
0x22, 0x28, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65,
0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x22, 0x22, 0x0a, 0x0c, 0x47, 0x65,
0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x65,
0x73, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x6d, 0x65, 0x73, 0x68, 0x32, 0x4f,
0x0a, 0x0e, 0x4d, 0x65, 0x73, 0x68, 0x43, 0x74, 0x72, 0x6c, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72,
0x12, 0x3d, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x18, 0x2e, 0x72, 0x70,
0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, 0x70, 0x65, 0x73,
0x2e, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42,
0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x33,
}
var (
@ -338,21 +146,16 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescGZIP() []byte {
return file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDescData
}
var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_pkg_grpc_ctrlserver_ctrlserver_proto_goTypes = []interface{}{
(*MeshNode)(nil), // 0: rpctypes.MeshNode
(*GetMeshRequest)(nil), // 1: rpctypes.GetMeshRequest
(*GetMeshReply)(nil), // 2: rpctypes.GetMeshReply
(*JoinMeshRequest)(nil), // 3: rpctypes.JoinMeshRequest
(*JoinMeshReply)(nil), // 4: rpctypes.JoinMeshReply
(*GetMeshRequest)(nil), // 0: rpctypes.GetMeshRequest
(*GetMeshReply)(nil), // 1: rpctypes.GetMeshReply
}
var file_pkg_grpc_ctrlserver_ctrlserver_proto_depIdxs = []int32{
1, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
3, // 1: rpctypes.MeshCtrlServer.JoinMesh:input_type -> rpctypes.JoinMeshRequest
2, // 2: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
4, // 3: rpctypes.MeshCtrlServer.JoinMesh:output_type -> rpctypes.JoinMeshReply
2, // [2:4] is the sub-list for method output_type
0, // [0:2] is the sub-list for method input_type
0, // 0: rpctypes.MeshCtrlServer.GetMesh:input_type -> rpctypes.GetMeshRequest
1, // 1: rpctypes.MeshCtrlServer.GetMesh:output_type -> rpctypes.GetMeshReply
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
@ -365,18 +168,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*MeshNode); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshRequest); i {
case 0:
return &v.state
@ -388,7 +179,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*GetMeshReply); i {
case 0:
return &v.state
@ -400,30 +191,6 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_ctrlserver_ctrlserver_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*JoinMeshReply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
@ -431,7 +198,7 @@ func file_pkg_grpc_ctrlserver_ctrlserver_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_ctrlserver_ctrlserver_proto_rawDesc,
NumEnums: 0,
NumMessages: 5,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},

View File

@ -23,7 +23,6 @@ const _ = grpc.SupportPackageIsVersion7
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type MeshCtrlServerClient interface {
GetMesh(ctx context.Context, in *GetMeshRequest, opts ...grpc.CallOption) (*GetMeshReply, error)
JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error)
}
type meshCtrlServerClient struct {
@ -43,21 +42,11 @@ func (c *meshCtrlServerClient) GetMesh(ctx context.Context, in *GetMeshRequest,
return out, nil
}
func (c *meshCtrlServerClient) JoinMesh(ctx context.Context, in *JoinMeshRequest, opts ...grpc.CallOption) (*JoinMeshReply, error) {
out := new(JoinMeshReply)
err := c.cc.Invoke(ctx, "/rpctypes.MeshCtrlServer/JoinMesh", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// MeshCtrlServerServer is the server API for MeshCtrlServer service.
// All implementations must embed UnimplementedMeshCtrlServerServer
// for forward compatibility
type MeshCtrlServerServer interface {
GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error)
JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error)
mustEmbedUnimplementedMeshCtrlServerServer()
}
@ -68,9 +57,6 @@ type UnimplementedMeshCtrlServerServer struct {
func (UnimplementedMeshCtrlServerServer) GetMesh(context.Context, *GetMeshRequest) (*GetMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) JoinMesh(context.Context, *JoinMeshRequest) (*JoinMeshReply, error) {
return nil, status.Errorf(codes.Unimplemented, "method JoinMesh not implemented")
}
func (UnimplementedMeshCtrlServerServer) mustEmbedUnimplementedMeshCtrlServerServer() {}
// UnsafeMeshCtrlServerServer may be embedded to opt out of forward compatibility for this service.
@ -102,24 +88,6 @@ func _MeshCtrlServer_GetMesh_Handler(srv interface{}, ctx context.Context, dec f
return interceptor(ctx, in, info, handler)
}
func _MeshCtrlServer_JoinMesh_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(JoinMeshRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/rpctypes.MeshCtrlServer/JoinMesh",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(MeshCtrlServerServer).JoinMesh(ctx, req.(*JoinMeshRequest))
}
return interceptor(ctx, in, info, handler)
}
// MeshCtrlServer_ServiceDesc is the grpc.ServiceDesc for MeshCtrlServer service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@ -131,10 +99,6 @@ var MeshCtrlServer_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetMesh",
Handler: _MeshCtrlServer_GetMesh_Handler,
},
{
MethodName: "JoinMesh",
Handler: _MeshCtrlServer_JoinMesh_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "pkg/grpc/ctrlserver/ctrlserver.proto",

View File

@ -1,7 +1,6 @@
package sync
import (
"errors"
"math/rand"
"sync"
"time"
@ -30,50 +29,35 @@ type SyncerImpl struct {
// Sync: Sync random nodes
func (s *SyncerImpl) Sync(meshId string) error {
logging.Log.WriteInfof("UPDATING WG CONF")
s.manager.ApplyConfig()
if !s.manager.HasChanges(meshId) && s.infectionCount == 0 {
logging.Log.WriteInfof("No changes for %s", meshId)
return nil
}
theMesh := s.manager.GetMesh(meshId)
logging.Log.WriteInfof("UPDATING WG CONF")
if theMesh == nil {
return errors.New("the provided mesh does not exist")
}
snapshot, err := theMesh.GetMesh()
if s.manager.HasChanges(meshId) {
err := s.manager.ApplyConfig()
if err != nil {
return err
}
nodes := snapshot.GetNodes()
if len(nodes) <= 1 {
return nil
logging.Log.WriteInfof("Failed to update config %w", err)
}
}
nodeNames := s.manager.GetMesh(meshId).GetPeers()
self, err := s.manager.GetSelf(meshId)
if err != nil {
return err
}
excludedNodes := map[string]struct{}{
self.GetHostEndpoint(): {},
}
meshNodes := lib.MapValuesWithExclude(nodes, excludedNodes)
selfPublickey, err := self.GetPublicKey()
getNames := func(node mesh.MeshNode) string {
return node.GetHostEndpoint()
if err != nil {
return err
}
nodeNames := lib.Map(meshNodes, getNames)
neighbours := s.cluster.GetNeighbours(nodeNames, self.GetHostEndpoint())
neighbours := s.cluster.GetNeighbours(nodeNames, selfPublickey.String())
randomSubset := lib.RandomSubsetOfLength(neighbours, s.conf.BranchRate)
for _, node := range randomSubset {
@ -82,9 +66,9 @@ func (s *SyncerImpl) Sync(meshId string) error {
before := time.Now()
if len(meshNodes) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
if len(nodeNames) > s.conf.ClusterSize && rand.Float64() < s.conf.InterClusterChance {
logging.Log.WriteInfof("Sending to random cluster")
interCluster := s.cluster.GetInterCluster(nodeNames, self.GetHostEndpoint())
interCluster := s.cluster.GetInterCluster(nodeNames, selfPublickey.String())
randomSubset = append(randomSubset, interCluster)
}
@ -95,7 +79,14 @@ func (s *SyncerImpl) Sync(meshId string) error {
go func(i int) error {
defer waitGroup.Done()
err := s.requester.SyncMesh(meshId, randomSubset[i])
correspondingPeer := s.manager.GetNode(meshId, randomSubset[i])
if correspondingPeer == nil {
logging.Log.WriteErrorf("node %s does not exist", randomSubset[i])
}
err := s.requester.SyncMesh(meshId, correspondingPeer.GetHostEndpoint())
return err
}(index)
}
@ -103,17 +94,20 @@ func (s *SyncerImpl) Sync(meshId string) error {
waitGroup.Wait()
s.syncCount++
logging.Log.WriteInfof("SYNC TIME: %v", time.Now().Sub(before))
logging.Log.WriteInfof("SYNC TIME: %v", time.Since(before))
logging.Log.WriteInfof("SYNC COUNT: %d", s.syncCount)
s.infectionCount = ((s.conf.InfectionCount + s.infectionCount - 1) % s.conf.InfectionCount)
// Check if any changes have occurred and trigger callbacks
// if changes have occurred.
// return s.manager.GetMonitor().Trigger()
return nil
}
// SyncMeshes: Sync all meshes
func (s *SyncerImpl) SyncMeshes() error {
for meshId, _ := range s.manager.GetMeshes() {
for meshId := range s.manager.GetMeshes() {
err := s.Sync(meshId)
if err != nil {

View File

@ -24,6 +24,13 @@ func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint stri
return false
}
// self, err := s.meshManager.GetSelf(meshId)
// if err != nil {
// return false
// }
// mesh.DecrementHealth(endpoint, self.GetHostEndpoint())
return true
}

View File

@ -50,7 +50,6 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId,
DevName: ifName,
WgPort: port,
MeshBytes: reply.Mesh,
})
@ -89,10 +88,12 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
c := rpc.NewSyncServiceClient(client)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
syncTimeOut := s.server.Conf.SyncRate * float64(time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(syncTimeOut))
defer cancel()
err = syncMesh(mesh, ctx, c)
err = s.syncMesh(mesh, ctx, c)
if err != nil {
return s.handleErr(meshId, endpoint, err)
@ -102,7 +103,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
return nil
}
func syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
func (s *SyncRequesterImpl) syncMesh(mesh mesh.MeshProvider, ctx context.Context, client rpc.SyncServiceClient) error {
stream, err := client.SyncMesh(ctx)
syncer := mesh.GetSyncer()

View File

@ -1,55 +1,18 @@
package sync
import (
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
// SyncScheduler: Loops through all nodes in the mesh and runs a schedule to
// sync each event
type SyncScheduler interface {
Run() error
Stop() error
}
// SyncSchedulerImpl scheduler for sync scheduling
type SyncSchedulerImpl struct {
quit chan struct{}
server *ctrlserver.MeshCtrlServer
syncer Syncer
}
// Run implements SyncScheduler.
func (s *SyncSchedulerImpl) Run() error {
ticker := time.NewTicker(time.Duration(s.server.Conf.SyncRate) * time.Second)
quit := make(chan struct{})
s.quit = quit
for {
select {
case <-ticker.C:
err := s.syncer.SyncMeshes()
if err != nil {
logging.Log.WriteErrorf(err.Error())
}
break
case <-quit:
break
}
func syncFunction(syncer Syncer) lib.TimerFunc {
return func() error {
return syncer.SyncMeshes()
}
}
// Stop implements SyncScheduler.
func (s *SyncSchedulerImpl) Stop() error {
close(s.quit)
return nil
}
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) SyncScheduler {
func NewSyncScheduler(s *ctrlserver.MeshCtrlServer, syncRequester SyncRequester) *lib.Timer {
syncer := NewSyncer(s.MeshManager, s.Conf, syncRequester)
return &SyncSchedulerImpl{server: s, syncer: syncer}
return lib.NewTimer(syncFunction(syncer), int(s.Conf.SyncRate))
}

View File

@ -64,11 +64,11 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
syncer = mesh.GetSyncer()
} else if meshId != in.MeshId {
return errors.New("Differing MeshIDs")
return errors.New("differing meshids")
}
if syncer == nil {
return errors.New("Syncer should not be nil")
return errors.New("syncer should not be nil")
}
msg, moreMessages := syncer.GenerateMessage()

22
pkg/timers/timers.go Normal file
View File

@ -0,0 +1,22 @@
package timer
import (
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
"github.com/tim-beatham/wgmesh/pkg/lib"
)
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
return ctrlServer.MeshManager.UpdateTimeStamp()
}
return *lib.NewTimer(timerFunc, ctrlServer.Conf.KeepAliveTime)
}
func NewRouteScheduler(ctrlServer *ctrlserver.MeshCtrlServer) lib.Timer {
timerFunc := func() error {
return ctrlServer.MeshManager.GetRouteManager().UpdateRoutes()
}
return *lib.NewTimer(timerFunc, 10)
}

View File

@ -1,51 +0,0 @@
package timestamp
import (
"time"
"github.com/tim-beatham/wgmesh/pkg/ctrlserver"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"github.com/tim-beatham/wgmesh/pkg/mesh"
)
type TimestampScheduler interface {
Run() error
Stop() error
}
type TimeStampSchedulerImpl struct {
meshManager mesh.MeshManager
updateRate int
quit chan struct{}
}
func (s *TimeStampSchedulerImpl) Run() error {
ticker := time.NewTicker(time.Duration(s.updateRate) * time.Second)
s.quit = make(chan struct{})
for {
select {
case <-ticker.C:
err := s.meshManager.UpdateTimeStamp()
if err != nil {
logging.Log.WriteErrorf("Update Timestamp Error: %s", err.Error())
}
case <-s.quit:
break
}
}
}
func NewTimestampScheduler(ctrlServer *ctrlserver.MeshCtrlServer) TimestampScheduler {
return &TimeStampSchedulerImpl{
meshManager: ctrlServer.MeshManager,
updateRate: ctrlServer.Conf.KeepAliveRate,
}
}
func (s *TimeStampSchedulerImpl) Stop() error {
close(s.quit)
return nil
}

View File

@ -2,10 +2,14 @@ package wg
type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(params *CreateInterfaceParams) error {
func (i *WgInterfaceManipulatorStub) CreateInterface(port int) (string, error) {
return "", nil
}
func (i *WgInterfaceManipulatorStub) AddAddress(ifName string, addr string) error {
return nil
}
func (i *WgInterfaceManipulatorStub) EnableInterface(ifName string, ip string) error {
func (i *WgInterfaceManipulatorStub) RemoveInterface(ifName string) error {
return nil
}

View File

@ -1,5 +1,7 @@
package wg
import "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
type WgError struct {
msg string
}
@ -8,15 +10,11 @@ func (m *WgError) Error() string {
return m.msg
}
type CreateInterfaceParams struct {
IfName string
Port int
}
type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface
CreateInterface(params *CreateInterfaceParams) error
// Enable interface enables the given interface with
// the IP. It overrides the IP at the interface
EnableInterface(ifName string, ip string) error
CreateInterface(port int, privateKey *wgtypes.Key) (string, error)
// AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface
RemoveInterface(ifName string) error
}

View File

@ -1,110 +1,93 @@
package wg
import (
"errors"
"crypto"
"crypto/rand"
"encoding/base64"
"fmt"
"net"
"os/exec"
"github.com/tim-beatham/wgmesh/pkg/lib"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// createInterface uses ip link to create an interface. If the interface exists
// it returns an error
func createInterface(ifName string) error {
_, err := net.InterfaceByName(ifName)
if err == nil {
err = flushInterface(ifName)
return err
}
// Check if the interface exists
cmd := exec.Command("/usr/bin/ip", "link", "add", "dev", ifName, "type", "wireguard")
if err := cmd.Run(); err != nil {
return err
}
return nil
}
type WgInterfaceManipulatorImpl struct {
client *wgctrl.Client
}
func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error {
err := createInterface(params.IfName)
const hashLength = 6
// CreateInterface creates a WireGuard interface
func (m *WgInterfaceManipulatorImpl) CreateInterface(port int, privKey *wgtypes.Key) (string, error) {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
return "", fmt.Errorf("failed to access link: %w", err)
}
defer rtnl.Close()
randomBuf := make([]byte, 32)
_, err = rand.Read(randomBuf)
if err != nil {
return "", err
}
privateKey, err := wgtypes.GeneratePrivateKey()
md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
err = rtnl.CreateLink(md5Str)
if err != nil {
return err
return "", fmt.Errorf("failed to create link: %w", err)
}
var cfg wgtypes.Config = wgtypes.Config{
PrivateKey: &privateKey,
ListenPort: &params.Port,
PrivateKey: privKey,
ListenPort: &port,
}
m.client.ConfigureDevice(params.IfName, cfg)
return nil
}
// flushInterface flushes the specified interface
func flushInterface(ifName string) error {
_, err := net.InterfaceByName(ifName)
err = m.client.ConfigureDevice(md5Str, cfg)
if err != nil {
return &WgError{msg: fmt.Sprintf("Interface %s does not exist cannot flush", ifName)}
m.RemoveInterface(md5Str)
return "", fmt.Errorf("failed to configure dev: %w", err)
}
cmd := exec.Command("/usr/bin/ip", "addr", "flush", "dev", ifName)
if err := cmd.Run(); err != nil {
logging.Log.WriteErrorf(fmt.Sprintf("%s error flushing interface %s", err.Error(), ifName))
return &WgError{msg: fmt.Sprintf("Failed to flush interface %s", ifName)}
logging.Log.WriteInfof("ip link set up dev %s type wireguard", md5Str)
return md5Str, nil
}
return nil
}
// EnableInterface flushes the interface and sets the ip address of the
// interface
func (m *WgInterfaceManipulatorImpl) EnableInterface(ifName string, ip string) error {
if len(ifName) == 0 {
return errors.New("ifName not provided")
}
err := flushInterface(ifName)
// Add an address to the given interface
func (m *WgInterfaceManipulatorImpl) AddAddress(ifName string, addr string) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err != nil {
return err
return fmt.Errorf("failed to create config: %w", err)
}
defer rtnl.Close()
cmd := exec.Command("/usr/bin/ip", "link", "set", "up", "dev", ifName)
if err := cmd.Run(); err != nil {
return err
}
hostIp, _, err := net.ParseCIDR(ip)
err = rtnl.AddAddress(ifName, addr)
if err != nil {
err = fmt.Errorf("failed to add address: %w", err)
}
return err
}
cmd = exec.Command("/usr/bin/ip", "addr", "add", hostIp.String()+"/64", "dev", ifName)
// RemoveInterface implements WgInterfaceManipulator.
func (*WgInterfaceManipulatorImpl) RemoveInterface(ifName string) error {
rtnl, err := lib.NewRtNetlinkConfig()
if err := cmd.Run(); err != nil {
return err
if err != nil {
return fmt.Errorf("failed to create config: %w", err)
}
return nil
defer rtnl.Close()
return rtnl.DeleteLink(ifName)
}
func NewWgInterfaceManipulator(client *wgctrl.Client) WgInterfaceManipulator {

View File

@ -0,0 +1,92 @@
// Package to convert an IPV6 addres into 8 words
package what8words
import (
"bufio"
"bytes"
"fmt"
"net"
"os"
"strings"
)
type What8Words struct {
words []string
}
// Convert implements What8Words.
func (w *What8Words) Convert(ipStr string) (string, error) {
ip, ipNet, err := net.ParseCIDR(ipStr)
if err != nil {
return "", err
}
ip16 := ip.To16()
if ip16 == nil {
return "", fmt.Errorf("cannot convert ip to 16 representation")
}
representation := make([]string, 7)
for i := 2; i <= net.IPv6len-2; i += 2 {
word1 := w.words[ip16[i]]
word2 := w.words[ip16[i+1]]
representation[i/2-1] = fmt.Sprintf("%s-%s", word1, word2)
}
prefixSize, _ := ipNet.Mask.Size()
return strings.Join(representation[:prefixSize/16-1], "."), nil
}
// Convert implements What8Words.
func (w *What8Words) ConvertIdentifier(ipStr string) (string, error) {
ip, err := w.Convert(ipStr)
if err != nil {
return "", err
}
constituents := strings.Split(ip, ".")
return strings.Join(constituents[3:], "."), nil
}
func NewWhat8Words(pathToWords string) (*What8Words, error) {
words, err := ReadWords(pathToWords)
if err != nil {
return nil, err
}
return &What8Words{words: words}, nil
}
// ReadWords reads the what 8 words txt file
func ReadWords(wordFile string) ([]string, error) {
f, err := os.ReadFile(wordFile)
if err != nil {
return nil, err
}
words := make([]string, 257)
reader := bufio.NewScanner(bytes.NewReader(f))
counter := 0
for reader.Scan() && counter <= len(words) {
text := reader.Text()
words[counter] = text
counter++
if reader.Err() != nil {
return nil, reader.Err()
}
}
return words, nil
}

1
smegmesh-web Submodule

Submodule smegmesh-web added at c1128bcd98