mirror of
https://github.com/tim-beatham/smegmesh.git
synced 2025-01-22 13:38:35 +01:00
Hashing the WireGuard interface
Hashing the interface and using ephmeral ports so that the admin doesn't choose an interface and port combination. An administrator can alteranatively decide to provide port but this isn't critical.
This commit is contained in:
parent
8f211aa116
commit
b179cd3cf4
@ -16,7 +16,6 @@ const SockAddr = "/tmp/wgmesh_ipc.sock"
|
|||||||
|
|
||||||
type CreateMeshParams struct {
|
type CreateMeshParams struct {
|
||||||
Client *ipcRpc.Client
|
Client *ipcRpc.Client
|
||||||
IfName string
|
|
||||||
WgPort int
|
WgPort int
|
||||||
Endpoint string
|
Endpoint string
|
||||||
}
|
}
|
||||||
@ -24,7 +23,6 @@ type CreateMeshParams struct {
|
|||||||
func createMesh(args *CreateMeshParams) string {
|
func createMesh(args *CreateMeshParams) string {
|
||||||
var reply string
|
var reply string
|
||||||
newMeshParams := ipc.NewMeshArgs{
|
newMeshParams := ipc.NewMeshArgs{
|
||||||
IfName: args.IfName,
|
|
||||||
WgPort: args.WgPort,
|
WgPort: args.WgPort,
|
||||||
Endpoint: args.Endpoint,
|
Endpoint: args.Endpoint,
|
||||||
}
|
}
|
||||||
@ -68,7 +66,6 @@ func joinMesh(params *JoinMeshParams) string {
|
|||||||
args := ipc.JoinMeshArgs{
|
args := ipc.JoinMeshArgs{
|
||||||
MeshId: params.MeshId,
|
MeshId: params.MeshId,
|
||||||
IpAdress: params.IpAddress,
|
IpAdress: params.IpAddress,
|
||||||
IfName: params.IfName,
|
|
||||||
Port: params.WgPort,
|
Port: params.WgPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -251,14 +248,12 @@ func main() {
|
|||||||
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
|
deleteServiceCmd := parser.NewCommand("delete-service", "Remove a service from your advertisements")
|
||||||
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
|
getNodeCmd := parser.NewCommand("get-node", "Get a specific node from the mesh")
|
||||||
|
|
||||||
var newMeshIfName *string = newMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
|
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{})
|
||||||
var newMeshPort *int = newMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
|
|
||||||
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
|
var newMeshEndpoint *string = newMeshCmd.String("e", "endpoint", &argparse.Options{})
|
||||||
|
|
||||||
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
|
var joinMeshId *string = joinMeshCmd.String("m", "mesh", &argparse.Options{Required: true})
|
||||||
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
|
var joinMeshIpAddress *string = joinMeshCmd.String("i", "ip", &argparse.Options{Required: true})
|
||||||
var joinMeshIfName *string = joinMeshCmd.String("f", "ifname", &argparse.Options{Required: true})
|
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{})
|
||||||
var joinMeshPort *int = joinMeshCmd.Int("p", "wgport", &argparse.Options{Required: true})
|
|
||||||
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
|
var joinMeshEndpoint *string = joinMeshCmd.String("e", "endpoint", &argparse.Options{})
|
||||||
|
|
||||||
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
|
var enableInterfaceMeshId *string = enableInterfaceCmd.String("m", "mesh", &argparse.Options{Required: true})
|
||||||
@ -298,7 +293,6 @@ func main() {
|
|||||||
if newMeshCmd.Happened() {
|
if newMeshCmd.Happened() {
|
||||||
fmt.Println(createMesh(&CreateMeshParams{
|
fmt.Println(createMesh(&CreateMeshParams{
|
||||||
Client: client,
|
Client: client,
|
||||||
IfName: *newMeshIfName,
|
|
||||||
WgPort: *newMeshPort,
|
WgPort: *newMeshPort,
|
||||||
Endpoint: *newMeshEndpoint,
|
Endpoint: *newMeshEndpoint,
|
||||||
}))
|
}))
|
||||||
@ -311,7 +305,6 @@ func main() {
|
|||||||
if joinMeshCmd.Happened() {
|
if joinMeshCmd.Happened() {
|
||||||
fmt.Println(joinMesh(&JoinMeshParams{
|
fmt.Println(joinMesh(&JoinMeshParams{
|
||||||
Client: client,
|
Client: client,
|
||||||
IfName: *joinMeshIfName,
|
|
||||||
WgPort: *joinMeshPort,
|
WgPort: *joinMeshPort,
|
||||||
IpAddress: *joinMeshIpAddress,
|
IpAddress: *joinMeshIpAddress,
|
||||||
MeshId: *joinMeshId,
|
MeshId: *joinMeshId,
|
||||||
|
@ -62,11 +62,11 @@ func (s *SmegServer) CreateMesh(c *gin.Context) {
|
|||||||
c.JSON(http.StatusBadRequest, &gin.H{
|
c.JSON(http.StatusBadRequest, &gin.H{
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
})
|
})
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ipcRequest := ipc.NewMeshArgs{
|
ipcRequest := ipc.NewMeshArgs{
|
||||||
IfName: createMesh.IfName,
|
|
||||||
WgPort: createMesh.WgPort,
|
WgPort: createMesh.WgPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -100,7 +100,6 @@ func (s *SmegServer) JoinMesh(c *gin.Context) {
|
|||||||
ipcRequest := ipc.JoinMeshArgs{
|
ipcRequest := ipc.JoinMeshArgs{
|
||||||
MeshId: joinMesh.MeshId,
|
MeshId: joinMesh.MeshId,
|
||||||
IpAdress: joinMesh.Bootstrap,
|
IpAdress: joinMesh.Bootstrap,
|
||||||
IfName: joinMesh.IfName,
|
|
||||||
Port: joinMesh.WgPort,
|
Port: joinMesh.WgPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,13 +18,11 @@ type SmegMesh struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CreateMeshRequest struct {
|
type CreateMeshRequest struct {
|
||||||
IfName string `json:"ifName" binding:"required"`
|
WgPort int `json:"port" binding:"gte=1024,lt=65535"`
|
||||||
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type JoinMeshRequest struct {
|
type JoinMeshRequest struct {
|
||||||
IfName string `json:"ifName" binding:"required"`
|
WgPort int `json:"port" binding:"gte=1024,lt=65535"`
|
||||||
WgPort int `json:"port" binding:"required,gte=1024,lt=65535"`
|
|
||||||
Bootstrap string `json:"bootstrap" binding:"required"`
|
Bootstrap string `json:"bootstrap" binding:"required"`
|
||||||
MeshId string `json:"meshid" binding:"required"`
|
MeshId string `json:"meshid" binding:"required"`
|
||||||
}
|
}
|
||||||
|
@ -11,8 +11,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type NewMeshArgs struct {
|
type NewMeshArgs struct {
|
||||||
// IfName is the interface that the mesh instance will run on
|
|
||||||
IfName string
|
|
||||||
// WgPort is the WireGuard port to expose
|
// WgPort is the WireGuard port to expose
|
||||||
WgPort int
|
WgPort int
|
||||||
// Endpoint is the routable alias of the machine. Can be an IP
|
// Endpoint is the routable alias of the machine. Can be an IP
|
||||||
@ -25,8 +23,6 @@ type JoinMeshArgs struct {
|
|||||||
MeshId string
|
MeshId string
|
||||||
// IpAddress is a routable IP in another mesh
|
// IpAddress is a routable IP in another mesh
|
||||||
IpAdress string
|
IpAdress string
|
||||||
// IfName is the interface name of the mesh
|
|
||||||
IfName string
|
|
||||||
// Port is the WireGuard port to expose
|
// Port is the WireGuard port to expose
|
||||||
Port int
|
Port int
|
||||||
// Endpoint is the routable address of this machine. If not provided
|
// Endpoint is the routable address of this machine. If not provided
|
||||||
|
@ -14,7 +14,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type MeshManager interface {
|
type MeshManager interface {
|
||||||
CreateMesh(devName string, port int) (string, error)
|
CreateMesh(port int) (string, error)
|
||||||
AddMesh(params *AddMeshParams) error
|
AddMesh(params *AddMeshParams) error
|
||||||
HasChanges(meshid string) bool
|
HasChanges(meshid string) bool
|
||||||
GetMesh(meshId string) MeshProvider
|
GetMesh(meshId string) MeshProvider
|
||||||
@ -115,15 +115,25 @@ func (m *MeshManagerImpl) Prune() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
|
// CreateMesh: Creates a new mesh, stores it and returns the mesh id
|
||||||
func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
|
func (m *MeshManagerImpl) CreateMesh(port int) (string, error) {
|
||||||
meshId, err := m.idGenerator.GetId()
|
meshId, err := m.idGenerator.GetId()
|
||||||
|
|
||||||
|
var ifName string = ""
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !m.conf.StubWg {
|
||||||
|
ifName, err = m.interfaceManipulator.CreateInterface(port)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error creating mesh: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
|
nodeManager, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
|
||||||
DevName: devName,
|
DevName: ifName,
|
||||||
Port: port,
|
Port: port,
|
||||||
Conf: m.conf,
|
Conf: m.conf,
|
||||||
Client: m.Client,
|
Client: m.Client,
|
||||||
@ -134,32 +144,31 @@ func (m *MeshManagerImpl) CreateMesh(devName string, port int) (string, error) {
|
|||||||
return "", fmt.Errorf("error creating mesh: %w", err)
|
return "", fmt.Errorf("error creating mesh: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.conf.StubWg {
|
|
||||||
err = m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
|
|
||||||
IfName: devName,
|
|
||||||
Port: port,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("error creating mesh: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
m.Meshes[meshId] = nodeManager
|
m.Meshes[meshId] = nodeManager
|
||||||
return meshId, nil
|
return meshId, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type AddMeshParams struct {
|
type AddMeshParams struct {
|
||||||
MeshId string
|
MeshId string
|
||||||
DevName string
|
|
||||||
WgPort int
|
WgPort int
|
||||||
MeshBytes []byte
|
MeshBytes []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddMesh: Add the mesh to the list of meshes
|
// AddMesh: Add the mesh to the list of meshes
|
||||||
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
|
func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
|
||||||
|
var ifName string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if !m.conf.StubWg {
|
||||||
|
ifName, err = m.interfaceManipulator.CreateInterface(params.WgPort)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
|
meshProvider, err := m.meshProviderFactory.CreateMesh(&MeshProviderFactoryParams{
|
||||||
DevName: params.DevName,
|
DevName: ifName,
|
||||||
Port: params.WgPort,
|
Port: params.WgPort,
|
||||||
Conf: m.conf,
|
Conf: m.conf,
|
||||||
Client: m.Client,
|
Client: m.Client,
|
||||||
@ -177,14 +186,6 @@ func (m *MeshManagerImpl) AddMesh(params *AddMeshParams) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.Meshes[params.MeshId] = meshProvider
|
m.Meshes[params.MeshId] = meshProvider
|
||||||
|
|
||||||
if !m.conf.StubWg {
|
|
||||||
return m.interfaceManipulator.CreateInterface(&wg.CreateInterfaceParams{
|
|
||||||
IfName: params.DevName,
|
|
||||||
Port: params.WgPort,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -250,7 +250,7 @@ func NewMeshManagerStub() MeshManager {
|
|||||||
return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
|
return &MeshManagerStub{meshes: make(map[string]MeshProvider)}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MeshManagerStub) CreateMesh(devName string, port int) (string, error) {
|
func (m *MeshManagerStub) CreateMesh(port int) (string, error) {
|
||||||
return "tim123", nil
|
return "tim123", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ type IpcHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
|
func (n *IpcHandler) CreateMesh(args *ipc.NewMeshArgs, reply *string) error {
|
||||||
meshId, err := n.Server.GetMeshManager().CreateMesh(args.IfName, args.WgPort)
|
meshId, err := n.Server.GetMeshManager().CreateMesh(args.WgPort)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -83,7 +83,6 @@ func (n *IpcHandler) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
|
|||||||
|
|
||||||
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
|
err = n.Server.GetMeshManager().AddMesh(&mesh.AddMeshParams{
|
||||||
MeshId: args.MeshId,
|
MeshId: args.MeshId,
|
||||||
DevName: args.IfName,
|
|
||||||
WgPort: args.Port,
|
WgPort: args.Port,
|
||||||
MeshBytes: meshReply.Mesh,
|
MeshBytes: meshReply.Mesh,
|
||||||
})
|
})
|
||||||
|
@ -50,7 +50,6 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, ifName string, port int, endP
|
|||||||
|
|
||||||
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
|
err = s.server.MeshManager.AddMesh(&mesh.AddMeshParams{
|
||||||
MeshId: meshId,
|
MeshId: meshId,
|
||||||
DevName: ifName,
|
|
||||||
WgPort: port,
|
WgPort: port,
|
||||||
MeshBytes: reply.Mesh,
|
MeshBytes: reply.Mesh,
|
||||||
})
|
})
|
||||||
|
@ -2,7 +2,7 @@ package wg
|
|||||||
|
|
||||||
type WgInterfaceManipulatorStub struct{}
|
type WgInterfaceManipulatorStub struct{}
|
||||||
|
|
||||||
func (i *WgInterfaceManipulatorStub) CreateInterface(params *CreateInterfaceParams) error {
|
func (i *WgInterfaceManipulatorStub) CreateInterface(port int) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -8,14 +8,9 @@ func (m *WgError) Error() string {
|
|||||||
return m.msg
|
return m.msg
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateInterfaceParams struct {
|
|
||||||
IfName string
|
|
||||||
Port int
|
|
||||||
}
|
|
||||||
|
|
||||||
type WgInterfaceManipulator interface {
|
type WgInterfaceManipulator interface {
|
||||||
// CreateInterface creates a WireGuard interface
|
// CreateInterface creates a WireGuard interface
|
||||||
CreateInterface(params *CreateInterfaceParams) error
|
CreateInterface(port int) (string, error)
|
||||||
// AddAddress adds an address to the given interface name
|
// AddAddress adds an address to the given interface name
|
||||||
AddAddress(ifName string, addr string) error
|
AddAddress(ifName string, addr string) error
|
||||||
// RemoveInterface removes the specified interface
|
// RemoveInterface removes the specified interface
|
||||||
|
37
pkg/wg/wg.go
37
pkg/wg/wg.go
@ -1,6 +1,9 @@
|
|||||||
package wg
|
package wg
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
"github.com/tim-beatham/wgmesh/pkg/lib"
|
||||||
@ -13,40 +16,54 @@ type WgInterfaceManipulatorImpl struct {
|
|||||||
client *wgctrl.Client
|
client *wgctrl.Client
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const hashLength = 6
|
||||||
|
|
||||||
// CreateInterface creates a WireGuard interface
|
// CreateInterface creates a WireGuard interface
|
||||||
func (m *WgInterfaceManipulatorImpl) CreateInterface(params *CreateInterfaceParams) error {
|
func (m *WgInterfaceManipulatorImpl) CreateInterface(port int) (string, error) {
|
||||||
rtnl, err := lib.NewRtNetlinkConfig()
|
rtnl, err := lib.NewRtNetlinkConfig()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to access link: %w", err)
|
return "", fmt.Errorf("failed to access link: %w", err)
|
||||||
}
|
}
|
||||||
defer rtnl.Close()
|
defer rtnl.Close()
|
||||||
|
|
||||||
err = rtnl.CreateLink(params.IfName)
|
randomBuf := make([]byte, 32)
|
||||||
|
_, err = rand.Read(randomBuf)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create link: %w", err)
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
md5 := crypto.MD5.New().Sum(randomBuf)
|
||||||
|
|
||||||
|
md5Str := fmt.Sprintf("wg%s", base64.StdEncoding.EncodeToString(md5)[:hashLength])
|
||||||
|
|
||||||
|
err = rtnl.CreateLink(md5Str)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to create link: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
privateKey, err := wgtypes.GeneratePrivateKey()
|
privateKey, err := wgtypes.GeneratePrivateKey()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create private key: %w", err)
|
return "", fmt.Errorf("failed to create private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var cfg wgtypes.Config = wgtypes.Config{
|
var cfg wgtypes.Config = wgtypes.Config{
|
||||||
PrivateKey: &privateKey,
|
PrivateKey: &privateKey,
|
||||||
ListenPort: ¶ms.Port,
|
ListenPort: &port,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.client.ConfigureDevice(params.IfName, cfg)
|
err = m.client.ConfigureDevice(md5Str, cfg)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to configure dev: %w", err)
|
m.RemoveInterface(md5Str)
|
||||||
|
return "", fmt.Errorf("failed to configure dev: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.Log.WriteInfof("ip link set up dev %s type wireguard", params.IfName)
|
logging.Log.WriteInfof("ip link set up dev %s type wireguard", md5Str)
|
||||||
return nil
|
return md5Str, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add an address to the given interface
|
// Add an address to the given interface
|
||||||
|
Loading…
Reference in New Issue
Block a user