Hashing the WireGuard interface

Hashing the interface and using ephmeral ports so that the admin doesn't
choose an interface and port combination. An administrator can alteranatively
decide to provide port but this isn't critical.
This commit is contained in:
Tim Beatham 2023-11-20 13:03:42 +00:00
parent 8f211aa116
commit b179cd3cf4
11 changed files with 61 additions and 64 deletions

View File

@ -16,7 +16,6 @@ const SockAddr = "/tmp/wgmesh_ipc.sock"
type CreateMeshParams struct { type CreateMeshParams struct {
Client *ipcRpc.Client Client *ipcRpc.Client
IfName string
WgPort int WgPort int
Endpoint string Endpoint string
} }
@ -24,7 +23,6 @@ type CreateMeshParams struct {
func createMesh(args *CreateMeshParams) string { func createMesh(args *CreateMeshParams) string {
var reply string var reply string
newMeshParams := ipc.NewMeshArgs{ newMeshParams := ipc.NewMeshArgs{
IfName: args.IfName,
WgPort: args.WgPort, WgPort: args.WgPort,
Endpoint: args.Endpoint, Endpoint: args.Endpoint,
} }
@ -68,7 +66,6 @@ func joinMesh(params *JoinMeshParams) string {
args := ipc.JoinMeshArgs{ args := ipc.JoinMeshArgs{
MeshId: params.MeshId, MeshId: params.MeshId,
IpAdress: params.IpAddress, IpAdress: params.IpAddress,
IfName: params.IfName,
Port: params.WgPort, Port: params.WgPort,
} }
@ -251,14 +248,12 @@ func main() {
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements") deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh") getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{}) var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true}) var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true}) var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true}) var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{}) var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true}) var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
@ -298,7 +293,6 @@ func main() {
if newMeshCmd.Happened() { if newMeshCmd.Happened() {
fmt.Println(createMesh(&CreateMeshParams{ fmt.Println(createMesh(&CreateMeshParams{
Client: client, Client: client,
IfName: *newMeshIfName,
WgPort: *newMeshPort, WgPort: *newMeshPort,
Endpoint: *newMeshEndpoint, Endpoint: *newMeshEndpoint,
})) }))
@ -311,7 +305,6 @@ func main() {
if joinMeshCmd.Happened() { if joinMeshCmd.Happened() {
fmt.Println(joinMesh(&JoinMeshParams{ fmt.Println(joinMesh(&JoinMeshParams{
Client: client, Client: client,
IfName: *joinMeshIfName,
WgPort: *joinMeshPort, WgPort: *joinMeshPort,
IpAddress: *joinMeshIpAddress, IpAddress: *joinMeshIpAddress,
MeshId: *joinMeshId, MeshId: *joinMeshId,

View File

@ -62,11 +62,11 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
c.JSON(http.StatusBadRequest, &gin.H{ c.JSON(http.StatusBadRequest, &gin.H{
"error": err.Error(), "error": err.Error(),
}) })
return return
} }
ipcRequest := ipc.NewMeshArgs{ ipcRequest := ipc.NewMeshArgs{
IfName: createMesh.IfName,
WgPort: createMesh.WgPort, WgPort: createMesh.WgPort,
} }
@ -100,7 +100,6 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
ipcRequest := ipc.JoinMeshArgs{ ipcRequest := ipc.JoinMeshArgs{
MeshId: joinMesh.MeshId, MeshId: joinMesh.MeshId,
IpAdress: joinMesh.Bootstrap, IpAdress: joinMesh.Bootstrap,
IfName: joinMesh.IfName,
Port: joinMesh.WgPort, Port: joinMesh.WgPort,
} }

View File

@ -18,13 +18,11 @@ type SmegMesh struct {
} }
type CreateMeshRequest struct { type CreateMeshRequest struct {
IfName string `json:"ifName" binding:"required"` WgPort int `json:"port" binding:"gte=1024,lt=65535"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
} }
type JoinMeshRequest struct { type JoinMeshRequest struct {
IfName string `json:"ifName" binding:"required"` WgPort int `json:"port" binding:"gte=1024,lt=65535"`
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
Bootstrap string `json:"bootstrap" binding:"required"` Bootstrap string `json:"bootstrap" binding:"required"`
MeshId string `json:"meshid" binding:"required"` MeshId string `json:"meshid" binding:"required"`
} }

View File

@ -11,8 +11,6 @@ import (
) )
type NewMeshArgs struct { type NewMeshArgs struct {
// IfName is the interface that the mesh instance will run on
IfName string
// WgPort is the WireGuard port to expose // WgPort is the WireGuard port to expose
WgPort int WgPort int
// Endpoint is the routable alias of the machine. Can be an IP // Endpoint is the routable alias of the machine. Can be an IP
@ -25,8 +23,6 @@ type JoinMeshArgs struct {
MeshId string MeshId string
// IpAddress is a routable IP in another mesh // IpAddress is a routable IP in another mesh
IpAdress string IpAdress string
// IfName is the interface name of the mesh
IfName string
// Port is the WireGuard port to expose // Port is the WireGuard port to expose
Port int Port int
// Endpoint is the routable address of this machine. If not provided // Endpoint is the routable address of this machine. If not provided

View File

@ -14,7 +14,7 @@ import (
) )
type MeshManager interface { type MeshManager interface {
CreateMesh(devName string, port int) (string, error) CreateMesh(port int) (string, error)
AddMesh(params *AddMeshParams) error AddMesh(params *AddMeshParams) error
HasChanges(meshid string) bool HasChanges(meshid string) bool
GetMesh(meshId string) MeshProvider GetMesh(meshId string) MeshProvider
@ -115,15 +115,25 @@ func (m *MeshManagerImpl) Prune() error {
} }
// CreateMesh: Creates a new mesh, stores it and returns the mesh id // CreateMesh: Creates a new mesh, stores it and returns the mesh id
func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) { func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
meshId, err := m.idGenerator.GetId() meshId, err := m.idGenerator.GetId()
var ifName string = ""
if err != nil { if err != nil {
return "", err return "", err
} }
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(port)
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: devName, DevName: ifName,
Port: port, Port: port,
Conf: m.conf, Conf: m.conf,
Client: m.Client, Client: m.Client,
@ -134,32 +144,31 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
return "", fmt.Errorf("error creating mesh: %w", err) return "", fmt.Errorf("error creating mesh: %w", err)
} }
if !m.conf.StubWg {
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: devName,
Port: port,
})
if err != nil {
return "", fmt.Errorf("error creating mesh: %w", err)
}
}
m.Meshes[meshId] = nodeManager m.Meshes[meshId] = nodeManager
return meshId, nil return meshId, nil
} }
type AddMeshParams struct { type AddMeshParams struct {
MeshId string MeshId string
DevName string
WgPort int WgPort int
MeshBytes []byte MeshBytes []byte
} }
// AddMesh: Add the mesh to the list of meshes // AddMesh: Add the mesh to the list of meshes
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error { func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
var ifName string
var err error
if !m.conf.StubWg {
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort)
if err != nil {
return err
}
}
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{ meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
DevName: params.DevName, DevName: ifName,
Port: params.WgPort, Port: params.WgPort,
Conf: m.conf, Conf: m.conf,
Client: m.Client, Client: m.Client,
@ -177,14 +186,6 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
} }
m.Meshes[params.MeshId] = meshProvider m.Meshes[params.MeshId] = meshProvider
if !m.conf.StubWg {
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
IfName: params.DevName,
Port: params.WgPort,
})
}
return nil return nil
} }

View File

@ -250,7 +250,7 @@ func NewMeshManagerStub() MeshManager {
return &MeshManagerStub{meshes: make(map[string]MeshProvider)} return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
} }
func (m *MeshManagerStub) CreateMesh(devName string, port int) (string, error) { func (m *MeshManagerStub) CreateMesh(port int) (string, error) {
return "tim123", nil return "tim123", nil
} }

View File

@ -20,7 +20,7 @@ type IpcHandler struct {
} }
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error { func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
meshId, err := n.Server.GetMeshManager().CreateMesh(args.IfName, args.WgPort) meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort)
if err != nil { if err != nil {
return err return err
@ -83,7 +83,6 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{ err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
MeshId: args.MeshId, MeshId: args.MeshId,
DevName: args.IfName,
WgPort: args.Port, WgPort: args.Port,
MeshBytes: meshReply.Mesh, MeshBytes: meshReply.Mesh,
}) })

View File

@ -50,7 +50,6 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{ err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
MeshId: meshId, MeshId: meshId,
DevName: ifName,
WgPort: port, WgPort: port,
MeshBytes: reply.Mesh, MeshBytes: reply.Mesh,
}) })

View File

@ -2,7 +2,7 @@ package wg
type WgInterfaceManipulatorStub struct{} type WgInterfaceManipulatorStub struct{}
func (i *WgInterfaceManipulatorStub) CreateInterface(params *CreateInterfaceParams) error { func (i *WgInterfaceManipulatorStub) CreateInterface(port int) error {
return nil return nil
} }

View File

@ -8,14 +8,9 @@ func (m *WgError) Error() string {
return m.msg return m.msg
} }
type CreateInterfaceParams struct {
IfName string
Port int
}
type WgInterfaceManipulator interface { type WgInterfaceManipulator interface {
// CreateInterface creates a WireGuard interface // CreateInterface creates a WireGuard interface
CreateInterface(params *CreateInterfaceParams) error CreateInterface(port int) (string, error)
// AddAddress adds an address to the given interface name // AddAddress adds an address to the given interface name
AddAddress(ifName string, addr string) error AddAddress(ifName string, addr string) error
// RemoveInterface removes the specified interface // RemoveInterface removes the specified interface

View File

@ -1,6 +1,9 @@
package wg package wg
import ( import (
"crypto"
"crypto/rand"
"encoding/base64"
"fmt" "fmt"
"github.com/tim-beatham/wgmesh/pkg/lib" "github.com/tim-beatham/wgmesh/pkg/lib"
@ -13,40 +16,54 @@ type WgInterfaceManipulatorImpl struct {
client *wgctrl.Client client *wgctrl.Client
} }
const hashLength = 6
// CreateInterface creates a WireGuard interface // CreateInterface creates a WireGuard interface
func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error { func (m *WgInterfaceManipulatorImpl) CreateInterface(port int) (string, error) {
rtnl, err := lib.NewRtNetlinkConfig() rtnl, err := lib.NewRtNetlinkConfig()
if err != nil { if err != nil {
return fmt.Errorf("failed to access link: %w", err) return "", fmt.Errorf("failed to access link: %w", err)
} }
defer rtnl.Close() defer rtnl.Close()
err = rtnl.CreateLink(params.IfName) randomBuf := make([]byte, 32)
_, err = rand.Read(randomBuf)
if err != nil { if err != nil {
return fmt.Errorf("failed to create link: %w", err) return "", err
}
md5 := crypto.MD5.New().Sum(randomBuf)
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
err = rtnl.CreateLink(md5Str)
if err != nil {
return "", fmt.Errorf("failed to create link: %w", err)
} }
privateKey, err := wgtypes.GeneratePrivateKey() privateKey, err := wgtypes.GeneratePrivateKey()
if err != nil { if err != nil {
return fmt.Errorf("failed to create private key: %w", err) return "", fmt.Errorf("failed to create private key: %w", err)
} }
var cfg wgtypes.Config = wgtypes.Config{ var cfg wgtypes.Config = wgtypes.Config{
PrivateKey: &privateKey, PrivateKey: &privateKey,
ListenPort: &params.Port, ListenPort: &port,
} }
err = m.client.ConfigureDevice(params.IfName, cfg) err = m.client.ConfigureDevice(md5Str, cfg)
if err != nil { if err != nil {
return fmt.Errorf("failed to configure dev: %w", err) m.RemoveInterface(md5Str)
return "", fmt.Errorf("failed to configure dev: %w", err)
} }
logging.Log.WriteInfof("ip link set up dev %s type wireguard", params.IfName) logging.Log.WriteInfof("ip link set up dev %s type wireguard", md5Str)
return nil return md5Str, nil
} }
// Add an address to the given interface // Add an address to the given interface