From a0c20e4d11f3da285c35fc6ed023f8aa0e6ab1ab Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Thu, 28 Sep 2023 16:55:37 +0100 Subject: [PATCH] IPV6 SLAAC --- cmd/wg-mesh/main.go | 6 +- cmd/wgmeshd/main.go | 13 ++-- pkg/cga/cga.go | 126 +++++++++++++++++++++++++++++++ pkg/ctrlserver/ctrlserver.go | 30 -------- pkg/ctrlserver/ipc/ipchandler.go | 29 +++++-- pkg/ctrlserver/rpc/rpchandler.go | 8 +- pkg/log/log.go | 22 ++++++ pkg/slaac/slaac.go | 39 ++++++++++ pkg/wg/wg.go | 4 +- wgmesh.go | 4 - 10 files changed, 223 insertions(+), 58 deletions(-) create mode 100644 pkg/cga/cga.go create mode 100644 pkg/log/log.go create mode 100644 pkg/slaac/slaac.go diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index 5f51297..851939b 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "log" ipcRpc "net/rpc" "os" @@ -29,7 +30,6 @@ func listMeshes(client *ipcRpc.Client) { err := client.Call("Mesh.ListMeshes", "", &reply) if err != nil { - fmt.Println(err.Error()) return } @@ -58,7 +58,7 @@ func getMesh(client *ipcRpc.Client, meshId string) { err := client.Call("Mesh.GetMesh", &meshId, &reply) if err != nil { - fmt.Println(err.Error()) + log.Panic(err.Error()) return } @@ -77,7 +77,7 @@ func enableInterface(client *ipcRpc.Client, meshId string) { err := client.Call("Mesh.EnableInterface", &meshId, &reply) if err != nil { - fmt.Println(err.Error()) + (err.Error()) return } diff --git a/cmd/wgmeshd/main.go b/cmd/wgmeshd/main.go index b4f57b7..ccdd19d 100644 --- a/cmd/wgmeshd/main.go +++ b/cmd/wgmeshd/main.go @@ -1,7 +1,7 @@ package main import ( - "fmt" + "log" "net" ctrlserver "github.com/tim-beatham/wgmesh/pkg/ctrlserver" @@ -10,24 +10,25 @@ import ( wg "github.com/tim-beatham/wgmesh/pkg/wg" ) +const ifName = "wgmesh" + func main() { - wgClient, err := wg.CreateClient("wgmesh") + wgClient, err := wg.CreateClient(ifName) if err != nil { - fmt.Println(err) - return + log.Fatalf("Could not create interface %s\n", ifName) } ctrlServer := ctrlserver.NewCtrlServer(wgClient, "wgmesh") - fmt.Println("Running IPC Handler") + log.Println("Running IPC Handler") go ipc.RunIpcHandler(ctrlServer) grpc := rpc.NewRpcServer(ctrlServer) lis, err := net.Listen("tcp", ":8080") if err := grpc.Serve(lis); err != nil { - fmt.Print(err.Error()) + log.Fatal(err.Error()) } defer wgClient.Close() diff --git a/pkg/cga/cga.go b/pkg/cga/cga.go new file mode 100644 index 0000000..dd57c71 --- /dev/null +++ b/pkg/cga/cga.go @@ -0,0 +1,126 @@ +package cga + +/* + * Use a WireGuard public key to generate a unique interface ID + */ + +import ( + "crypto/rand" + "crypto/sha1" + "net" + + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +const ( + ModifierLength = 16 + ZeroLength = 9 + hash2Length = 57 + hash1Length = 58 + Hash2Prefix = 14 + Hash1Prefix = 8 + InterfaceIdLen = 8 + SubnetPrefixLen = 8 +) + +/* + * Cga parameters used to generate an IPV6 interface ID + */ +type CgaParameters struct { + Modifier [ModifierLength]byte + SubnetPrefix [SubnetPrefixLen]byte + CollisionCount uint8 + PublicKey wgtypes.Key + interfaceId [2 * InterfaceIdLen]byte + flag byte +} + +func NewCga(key wgtypes.Key, subnetPrefix [SubnetPrefixLen]byte) (*CgaParameters, error) { + var params CgaParameters + + _, err := rand.Read(params.Modifier[:]) + + if err != nil { + return nil, err + } + + params.PublicKey = key + params.SubnetPrefix = subnetPrefix + return ¶ms, nil +} + +func (c *CgaParameters) generateHash2() []byte { + var byteVal [hash2Length]byte + + for i := 0; i < ModifierLength; i++ { + byteVal[i] = c.Modifier[i] + } + + for i := 0; i < wgtypes.KeyLen; i++ { + byteVal[ModifierLength+ZeroLength+i] = c.PublicKey[i] + } + + hash := sha1.Sum(byteVal[:]) + + return hash[:Hash2Prefix] +} + +func (c *CgaParameters) generateHash1() []byte { + var byteVal [hash1Length]byte + + for i := 0; i < ModifierLength; i++ { + byteVal[i] = c.Modifier[i] + } + + for i := 0; i < wgtypes.KeyLen; i++ { + byteVal[ModifierLength+ZeroLength+i] = c.PublicKey[i] + } + + byteVal[hash1Length-1] = c.CollisionCount + + hash := sha1.Sum(byteVal[:]) + + return hash[:Hash1Prefix] +} + +func clearBit(num, pos int) byte { + mask := ^(1 << pos) + result := num & mask + + return byte(result) +} + +func (c *CgaParameters) generateInterface() []byte { + // TODO: On duplicate address detection increment collision. + // Also incorporate SEC + + hash1 := c.generateHash1() + + var interfaceId []byte = make([]byte, InterfaceIdLen) + + copy(interfaceId[:], hash1) + + interfaceId[0] = clearBit(int(interfaceId[0]), 6) + interfaceId[0] = clearBit(int(interfaceId[1]), 7) + + return interfaceId +} + +func (c *CgaParameters) GetIpv6() net.IP { + if c.flag == 1 { + return c.interfaceId[:] + } + + bytes := c.generateInterface() + + for i := 0; i < InterfaceIdLen; i++ { + c.interfaceId[i] = c.SubnetPrefix[i] + } + + for i := InterfaceIdLen; i < 2*InterfaceIdLen; i++ { + c.interfaceId[i] = bytes[i-8] + } + + c.flag = 1 + return c.interfaceId[:] +} diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go index d2dbc8c..b14fe72 100644 --- a/pkg/ctrlserver/ctrlserver.go +++ b/pkg/ctrlserver/ctrlserver.go @@ -2,9 +2,7 @@ package ctrlserver import ( "errors" - "fmt" "net" - "strconv" "github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/wg" @@ -33,20 +31,6 @@ func (server *MeshCtrlServer) IsInMesh(meshId string) bool { return inMesh } -func (server *MeshCtrlServer) addSelfToMesh(meshId string) error { - ipAddr := lib.GetOutboundIP() - - node := MeshNode{ - HostEndpoint: ipAddr.String() + ":8080", - PublicKey: server.GetDevice().PublicKey.String(), - WgEndpoint: ipAddr.String() + ":51820", - WgHost: "10.0.0.1/32", - } - - server.Meshes[meshId].Nodes[node.HostEndpoint] = node - return nil -} - func (server *MeshCtrlServer) CreateMesh() (*Mesh, error) { key, err := wgtypes.GenerateKey() @@ -60,7 +44,6 @@ func (server *MeshCtrlServer) CreateMesh() (*Mesh, error) { } server.Meshes[key.String()] = mesh - server.addSelfToMesh(mesh.SharedKey.String()) return &mesh, nil } @@ -96,8 +79,6 @@ func (server *MeshCtrlServer) AddHost(args AddHostArgs) error { if err == nil { nodes.Nodes[args.HostEndpoint] = node - } else { - fmt.Println(err.Error()) } return err @@ -117,8 +98,6 @@ func AddWgPeer(ifName string, client *wgctrl.Client, node MeshNode) error { peer := make([]wgtypes.PeerConfig, 1) peerPublic, err := wgtypes.ParseKey(node.PublicKey) - fmt.Println("node.PublicKey: " + node.PublicKey) - fmt.Println("peerPublic: " + peerPublic.String()) if err != nil { return err @@ -127,7 +106,6 @@ func AddWgPeer(ifName string, client *wgctrl.Client, node MeshNode) error { peerEndpoint, err := net.ResolveUDPAddr("udp", node.WgEndpoint) if err != nil { - fmt.Println("err") return err } @@ -152,14 +130,6 @@ func AddWgPeer(ifName string, client *wgctrl.Client, node MeshNode) error { err = client.ConfigureDevice(ifName, cfg) - if err != nil { - fmt.Println(err.Error()) - } - - dev, err := client.Device(ifName) - - fmt.Println("Number of peers: " + strconv.Itoa(len(dev.Peers))) - if err != nil { return err } diff --git a/pkg/ctrlserver/ipc/ipchandler.go b/pkg/ctrlserver/ipc/ipchandler.go index b3ea268..16d6536 100644 --- a/pkg/ctrlserver/ipc/ipchandler.go +++ b/pkg/ctrlserver/ipc/ipchandler.go @@ -3,7 +3,6 @@ package ipc import ( "context" "errors" - "fmt" "net" "net/http" ipcRpc "net/rpc" @@ -16,6 +15,9 @@ import ( "github.com/tim-beatham/wgmesh/pkg/ctrlserver/rpc" "github.com/tim-beatham/wgmesh/pkg/ipc" ipctypes "github.com/tim-beatham/wgmesh/pkg/ipc" + "github.com/tim-beatham/wgmesh/pkg/lib" + logging "github.com/tim-beatham/wgmesh/pkg/log" + "github.com/tim-beatham/wgmesh/pkg/slaac" "github.com/tim-beatham/wgmesh/pkg/wg" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" @@ -28,13 +30,28 @@ type Mesh struct { Server *ctrlserver.MeshCtrlServer } +const MeshIfName = "wgmesh" + /* * Create a new WireGuard mesh network */ func (n Mesh) CreateNewMesh(name *string, reply *string) error { - wg.CreateInterface("wgmesh") + wg.CreateInterface(MeshIfName) mesh, err := n.Server.CreateMesh() + ula, _ := slaac.NewULA(n.Server.GetDevice().PublicKey, "0") + + outBoundIp := lib.GetOutboundIP().String() + + addHostArgs := ctrlserver.AddHostArgs{ + HostEndpoint: outBoundIp + ":8080", + PublicKey: n.Server.GetDevice().PublicKey.String(), + WgEndpoint: outBoundIp + ":51820", + WgIp: ula.CGA.GetIpv6().String() + "/128", + MeshId: mesh.SharedKey.String(), + } + + n.Server.AddHost(addHostArgs) if err != nil { return err @@ -113,6 +130,7 @@ func updatePeer(n *Mesh, node ctrlserver.MeshNode, wgHost string, meshId string) defer cancel() dev := n.Server.GetDevice() + joinMeshReq := rpc.JoinMeshRequest{ MeshId: meshId, HostPort: 8080, @@ -143,7 +161,7 @@ func updatePeers(n *Mesh, meshId string, wgHost string, nodesToExclude []string) err := updatePeer(n, node, wgHost, meshId) if err != nil { - fmt.Println(err.Error()) + return err } } } @@ -166,14 +184,16 @@ func (n Mesh) JoinMesh(args *ipctypes.JoinMeshArgs, reply *string) error { defer cancel() dev := n.Server.GetDevice() + ula, _ := slaac.NewULA(dev.PublicKey, "0") - fmt.Print("Pub Key:" + dev.PublicKey.String()) + logging.InfoLog.Println("WgIP: " + ula.CGA.GetIpv6().String()) joinMeshReq := rpc.JoinMeshRequest{ MeshId: args.MeshId, HostPort: 8080, PublicKey: dev.PublicKey.String(), WgPort: int32(dev.ListenPort), + WgIp: ula.CGA.GetIpv6().String() + "/128", } r, err := c.JoinMesh(ctx, &joinMeshReq) @@ -200,7 +220,6 @@ func (n Mesh) GetMesh(meshId string, reply *ipc.GetMeshReply) error { i := 0 for _, n := range mesh.Nodes { - fmt.Println(n.PublicKey) nodes[i] = n i += 1 } diff --git a/pkg/ctrlserver/rpc/rpchandler.go b/pkg/ctrlserver/rpc/rpchandler.go index ba5b53f..be0297b 100644 --- a/pkg/ctrlserver/rpc/rpchandler.go +++ b/pkg/ctrlserver/rpc/rpchandler.go @@ -3,8 +3,6 @@ package rpc import ( context "context" "errors" - "fmt" - "math/rand" "net" "strconv" @@ -58,7 +56,6 @@ func (m *meshCtrlServer) GetMesh(ctx context.Context, request *GetMeshRequest) ( func (m *meshCtrlServer) JoinMesh(ctx context.Context, request *JoinMeshRequest) (*JoinMeshReply, error) { p, _ := peer.FromContext(ctx) - fmt.Println(p.Addr.String()) hostIp, _, err := net.SplitHostPort(p.Addr.String()) @@ -69,12 +66,9 @@ func (m *meshCtrlServer) JoinMesh(ctx context.Context, request *JoinMeshRequest) wgIp := request.WgIp if wgIp == "" { - wgIp = "10.0.0." + strconv.Itoa(rand.Intn(253)+1) + "/32" + return nil, errors.New("Haven't provided a valid IP address") } - fmt.Println("Join server public key: " + request.PublicKey) - fmt.Println("Request: " + request.MeshId) - addHostArgs := ctrlserver.AddHostArgs{ HostEndpoint: hostIp + ":" + strconv.Itoa(int(request.HostPort)), PublicKey: request.PublicKey, diff --git a/pkg/log/log.go b/pkg/log/log.go new file mode 100644 index 0000000..56310ae --- /dev/null +++ b/pkg/log/log.go @@ -0,0 +1,22 @@ +package logging + +/* + * This package creates the info, warning and error loggers. + */ + +import ( + "log" + "os" +) + +var ( + InfoLog *log.Logger + WarningLog *log.Logger + ErrorLog *log.Logger +) + +func init() { + InfoLog = log.New(os.Stdout, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile) + WarningLog = log.New(os.Stdout, "WARNING: ", log.Ldate|log.Ltime|log.Lshortfile) + ErrorLog = log.New(os.Stderr, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile) +} diff --git a/pkg/slaac/slaac.go b/pkg/slaac/slaac.go new file mode 100644 index 0000000..5eea8ef --- /dev/null +++ b/pkg/slaac/slaac.go @@ -0,0 +1,39 @@ +package slaac + +import ( + "crypto/sha1" + + "github.com/tim-beatham/wgmesh/pkg/cga" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type ULA struct { + CGA cga.CgaParameters +} + +func getULAPrefix(meshId string) [8]byte { + var ulaPrefix [8]byte + + ulaPrefix[0] = 0xfd + + s := sha1.Sum([]byte(meshId)) + + for i := 1; i < 7; i++ { + ulaPrefix[i] = s[i-1] + } + + ulaPrefix[7] = 1 + return ulaPrefix +} + +func NewULA(key wgtypes.Key, meshId string) (*ULA, error) { + ulaPrefix := getULAPrefix(meshId) + + c, err := cga.NewCga(key, ulaPrefix) + + if err != nil { + return nil, err + } + + return &ULA{CGA: *c}, nil +} diff --git a/pkg/wg/wg.go b/pkg/wg/wg.go index 2446d5f..ea17373 100644 --- a/pkg/wg/wg.go +++ b/pkg/wg/wg.go @@ -20,7 +20,6 @@ func CreateInterface(ifName string) error { cmd := exec.Command("/usr/bin/ip", "link", "add", "dev", ifName, "type", "wireguard") if err := cmd.Run(); err != nil { - fmt.Println(err.Error()) return err } } @@ -74,10 +73,9 @@ func EnableInterface(ifName string, ip string) error { return err } - cmd = exec.Command("/usr/bin/ip", "addr", "add", hostIp.String()+"/24", "dev", "wgmesh") + cmd = exec.Command("/usr/bin/ip", "addr", "add", hostIp.String()+"/64", "dev", "wgmesh") if err := cmd.Run(); err != nil { - fmt.Println(err.Error()) return err } diff --git a/wgmesh.go b/wgmesh.go index 012f558..916be71 100644 --- a/wgmesh.go +++ b/wgmesh.go @@ -11,7 +11,6 @@ func main() { client, err := wgctrl.New() if err != nil { - fmt.Println("Error creating device") return } @@ -19,7 +18,6 @@ func main() { var listenPort int = 5109 if err != nil { - fmt.Println("Error creating private key") return } @@ -31,14 +29,12 @@ func main() { err = client.ConfigureDevice("utun9", cfg) if err != nil { - fmt.Println(err.Error()) return } devices, err := client.Devices() if err != nil { - fmt.Println("unable to retrieve devices") return }