forked from extern/smegmesh
Tested with large number of nodes
This commit is contained in:
parent
ef2b57047d
commit
8e89281484
@ -31,7 +31,7 @@ func listMeshes(client *ipcRpc.Client) {
|
|||||||
err := client.Call("RobinIpc.ListMeshes", "", &reply)
|
err := client.Call("RobinIpc.ListMeshes", "", &reply)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.ErrorLog.Println(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -36,7 +36,6 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
|
ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams)
|
||||||
authProvider.Manager = ctrlServer.ConnectionServer.JwtManager
|
|
||||||
syncProvider.Server = ctrlServer
|
syncProvider.Server = ctrlServer
|
||||||
syncRequester := sync.NewSyncRequester(ctrlServer)
|
syncRequester := sync.NewSyncRequester(ctrlServer)
|
||||||
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, 2)
|
syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, 2)
|
||||||
@ -50,7 +49,8 @@ func main() {
|
|||||||
robinIpc = robin.NewRobinIpc(robinIpcParams)
|
robinIpc = robin.NewRobinIpc(robinIpcParams)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.ErrorLog.Fatalln(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("Running IPC Handler")
|
log.Println("Running IPC Handler")
|
||||||
@ -61,9 +61,12 @@ func main() {
|
|||||||
err = ctrlServer.ConnectionServer.Listen()
|
err = ctrlServer.ConnectionServer.Listen()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.ErrorLog.Fatalln(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
defer wgClient.Close()
|
|
||||||
defer syncScheduler.Stop()
|
defer syncScheduler.Stop()
|
||||||
|
defer ctrlServer.Close()
|
||||||
|
defer wgClient.Close()
|
||||||
}
|
}
|
||||||
|
140
pkg/auth/jwt.go
140
pkg/auth/jwt.go
@ -1,140 +0,0 @@
|
|||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/metadata"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
)
|
|
||||||
|
|
||||||
// JwtMesh contains all the sessions with the mesh network
|
|
||||||
type JwtMesh struct {
|
|
||||||
meshId string
|
|
||||||
// nodes contains a set of nodes with the string being the jwt token
|
|
||||||
nodes map[string]interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// JwtManager manages jwt tokens indicating a session
|
|
||||||
// between this host and another within a specific mesh
|
|
||||||
type JwtManager struct {
|
|
||||||
secretKey []byte
|
|
||||||
tokenDuration time.Duration
|
|
||||||
// meshes contains all the meshes that we have sessions with
|
|
||||||
meshes map[string]*JwtMesh
|
|
||||||
}
|
|
||||||
|
|
||||||
// JwtNode represents a jwt node in the mesh network
|
|
||||||
type JwtNode struct {
|
|
||||||
MeshId string `json:"meshId"`
|
|
||||||
Alias string `json:"alias"`
|
|
||||||
jwt.RegisteredClaims
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewJwtManager(secretKey string, tokenDuration time.Duration) *JwtManager {
|
|
||||||
meshes := make(map[string]*JwtMesh)
|
|
||||||
return &JwtManager{[]byte(secretKey), tokenDuration, meshes}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *JwtManager) CreateClaims(meshId string, alias string) (*string, error) {
|
|
||||||
logging.InfoLog.Println("MeshID: " + meshId)
|
|
||||||
logging.InfoLog.Println("Token Duration: " + strconv.Itoa(int(m.tokenDuration)))
|
|
||||||
node := JwtNode{
|
|
||||||
MeshId: meshId,
|
|
||||||
Alias: alias,
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(m.tokenDuration)),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
mesh, contains := m.meshes[meshId]
|
|
||||||
|
|
||||||
if !contains {
|
|
||||||
mesh = new(JwtMesh)
|
|
||||||
mesh.meshId = meshId
|
|
||||||
mesh.nodes = make(map[string]interface{})
|
|
||||||
mesh.nodes[meshId] = mesh
|
|
||||||
}
|
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, node)
|
|
||||||
signedString, err := token.SignedString(m.secretKey)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err.Error())
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, exists := mesh.nodes[signedString]
|
|
||||||
|
|
||||||
if exists {
|
|
||||||
return nil, errors.New("Node already exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
mesh.nodes[signedString] = struct{}{}
|
|
||||||
return &signedString, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *JwtManager) Verify(accessToken string) (*JwtNode, bool) {
|
|
||||||
token, err := jwt.ParseWithClaims(accessToken, &JwtNode{}, func(t *jwt.Token) (interface{}, error) {
|
|
||||||
return m.secretKey, nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
|
|
||||||
if !token.Valid {
|
|
||||||
return nil, token.Valid
|
|
||||||
}
|
|
||||||
|
|
||||||
claims, ok := token.Claims.(*JwtNode)
|
|
||||||
return claims, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *JwtManager) GetAuthInterceptor() grpc.UnaryServerInterceptor {
|
|
||||||
return func(
|
|
||||||
ctx context.Context,
|
|
||||||
req interface{},
|
|
||||||
info *grpc.UnaryServerInfo,
|
|
||||||
handler grpc.UnaryHandler,
|
|
||||||
) (interface{}, error) {
|
|
||||||
|
|
||||||
if strings.Contains(info.FullMethod, "") {
|
|
||||||
return handler(ctx, req)
|
|
||||||
}
|
|
||||||
|
|
||||||
md, ok := metadata.FromIncomingContext(ctx)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "metadata is not provided")
|
|
||||||
}
|
|
||||||
|
|
||||||
values := md["authorization"]
|
|
||||||
|
|
||||||
for _, w := range values {
|
|
||||||
logging.InfoLog.Printf(w)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(values) == 0 {
|
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "authorization token is not provided")
|
|
||||||
}
|
|
||||||
|
|
||||||
acessToken := values[0]
|
|
||||||
|
|
||||||
_, valid := m.Verify(acessToken)
|
|
||||||
|
|
||||||
if !valid {
|
|
||||||
return nil, status.Errorf(codes.Unauthenticated, "Invalid access token: %s", acessToken)
|
|
||||||
}
|
|
||||||
|
|
||||||
return handler(ctx, req)
|
|
||||||
}
|
|
||||||
}
|
|
@ -193,14 +193,13 @@ func (m *CrdtNodeManager) Length() int {
|
|||||||
return m.doc.Path("nodes").Map().Len()
|
return m.doc.Path("nodes").Map().Len()
|
||||||
}
|
}
|
||||||
|
|
||||||
const threshold = 2
|
|
||||||
const thresholdVotes = 0.1
|
const thresholdVotes = 0.1
|
||||||
|
|
||||||
func (m *CrdtNodeManager) HasFailed(endpoint string) bool {
|
func (m *CrdtNodeManager) HasFailed(endpoint string) bool {
|
||||||
node, err := m.GetNode(endpoint)
|
node, err := m.GetNode(endpoint)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.InfoLog.Printf("Cannot get node node: %s\n", endpoint)
|
logging.Log.WriteErrorf("Cannot get node node: %s\n", endpoint)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,14 +214,12 @@ func (m *CrdtNodeManager) HasFailed(endpoint string) bool {
|
|||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
count := value.Int64()
|
count := value.Int64()
|
||||||
|
|
||||||
if count >= threshold {
|
if count >= 1 {
|
||||||
countFailed++
|
countFailed++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.InfoLog.Printf("Count Failed Value: %d\n", countFailed)
|
return countFailed >= 4
|
||||||
logging.InfoLog.Printf("Threshold Value: %d\n", int(thresholdVotes*float64(m.Length())+1))
|
|
||||||
return countFailed >= int(thresholdVotes*float64(m.Length())+1)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
|
func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error {
|
||||||
@ -232,7 +229,6 @@ func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNode
|
|||||||
|
|
||||||
for _, n := range nodes {
|
for _, n := range nodes {
|
||||||
peer, err := m.convertMeshNode(n)
|
peer, err := m.convertMeshNode(n)
|
||||||
logging.InfoLog.Println(n.HostEndpoint)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -24,14 +24,14 @@ func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) {
|
|||||||
yamlBytes, err := os.ReadFile(filePath)
|
yamlBytes, err := os.ReadFile(filePath)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.ErrorLog.Printf("Read file error: %s\n", err.Error())
|
logging.Log.WriteErrorf("Read file error: %s\n", err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = yaml.Unmarshal(yamlBytes, &conf)
|
err = yaml.Unmarshal(yamlBytes, &conf)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.ErrorLog.Printf("Unmarshal error: %s\n", err.Error())
|
logging.Log.WriteErrorf("Unmarshal error: %s\n", err.Error())
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
116
pkg/conn/conn.go
116
pkg/conn/conn.go
@ -1,116 +0,0 @@
|
|||||||
// conn manages gRPC connections between peers.
|
|
||||||
// Includes timers.
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/lib"
|
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/rpc"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
"google.golang.org/grpc/credentials"
|
|
||||||
"google.golang.org/grpc/metadata"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PeerConnection interfacing for a secure connection between
|
|
||||||
// two peers.
|
|
||||||
type PeerConnection interface {
|
|
||||||
Connect() error
|
|
||||||
Close() error
|
|
||||||
Authenticate(meshId string) error
|
|
||||||
GetClient() (*grpc.ClientConn, error)
|
|
||||||
CreateAuthContext(meshId string) (context.Context, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type WgCtrlConnection struct {
|
|
||||||
clientConfig *tls.Config
|
|
||||||
conn *grpc.ClientConn
|
|
||||||
endpoint string
|
|
||||||
// tokens maps a meshID to the corresponding token
|
|
||||||
tokens map[string]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnection, error) {
|
|
||||||
var conn WgCtrlConnection
|
|
||||||
conn.tokens = make(map[string]string)
|
|
||||||
conn.clientConfig = clientConfig
|
|
||||||
conn.endpoint = server
|
|
||||||
return &conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *WgCtrlConnection) Authenticate(meshId string) error {
|
|
||||||
conn, err := grpc.Dial(c.endpoint,
|
|
||||||
grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)))
|
|
||||||
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
||||||
|
|
||||||
client := rpc.NewAuthenticationClient(conn)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
authRequest := rpc.JoinAuthMeshRequest{
|
|
||||||
MeshId: meshId,
|
|
||||||
Alias: lib.GetOutboundIP().String(),
|
|
||||||
}
|
|
||||||
|
|
||||||
reply, err := client.JoinMesh(ctx, &authRequest)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.tokens[meshId] = *reply.Token
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnectWithToken: Connects to a new gRPC peer given the address of the other server.
|
|
||||||
func (c *WgCtrlConnection) Connect() error {
|
|
||||||
conn, err := grpc.Dial(c.endpoint,
|
|
||||||
grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)),
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logging.ErrorLog.Printf("Could not connect: %s\n", err.Error())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.conn = conn
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close: Closes the client connections
|
|
||||||
func (c *WgCtrlConnection) Close() error {
|
|
||||||
return c.conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClient: Gets the client connection
|
|
||||||
func (c *WgCtrlConnection) GetClient() (*grpc.ClientConn, error) {
|
|
||||||
var err error = nil
|
|
||||||
|
|
||||||
if c.conn == nil {
|
|
||||||
err = errors.New("The client's config does not exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
return c.conn, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Implement a mechanism to attach a security token
|
|
||||||
func (c *WgCtrlConnection) CreateAuthContext(meshId string) (context.Context, error) {
|
|
||||||
token, ok := c.tokens[meshId]
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("MeshID: " + meshId + " does not exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
return metadata.AppendToOutgoingContext(ctx, "authorization", token), nil
|
|
||||||
}
|
|
@ -1,93 +0,0 @@
|
|||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ConnectionManager interface {
|
|
||||||
AddConnection(endPoint string) (PeerConnection, error)
|
|
||||||
GetConnection(endPoint string) (PeerConnection, error)
|
|
||||||
HasConnection(endPoint string) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnectionManager manages connections between other peers
|
|
||||||
// in the control plane.
|
|
||||||
type JwtConnectionManager struct {
|
|
||||||
// clientConnections maps an endpoint to a connection
|
|
||||||
clientConnections map[string]PeerConnection
|
|
||||||
serverConfig *tls.Config
|
|
||||||
clientConfig *tls.Config
|
|
||||||
}
|
|
||||||
|
|
||||||
type NewJwtConnectionManagerParams struct {
|
|
||||||
CertificatePath string
|
|
||||||
PrivateKey string
|
|
||||||
SkipCertVerification bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewJwtConnectionManager(params *NewJwtConnectionManagerParams) (ConnectionManager, error) {
|
|
||||||
cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logging.ErrorLog.Printf("Failed to load key pair: %s\n", err.Error())
|
|
||||||
logging.ErrorLog.Printf("Certificate Path: %s\n", params.CertificatePath)
|
|
||||||
logging.ErrorLog.Printf("Private Key Path: %s\n", params.PrivateKey)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
serverAuth := tls.RequireAndVerifyClientCert
|
|
||||||
|
|
||||||
if params.SkipCertVerification {
|
|
||||||
serverAuth = tls.RequireAnyClientCert
|
|
||||||
}
|
|
||||||
|
|
||||||
serverConfig := &tls.Config{
|
|
||||||
ClientAuth: serverAuth,
|
|
||||||
Certificates: []tls.Certificate{cert},
|
|
||||||
}
|
|
||||||
|
|
||||||
clientConfig := &tls.Config{
|
|
||||||
Certificates: []tls.Certificate{cert},
|
|
||||||
InsecureSkipVerify: params.SkipCertVerification,
|
|
||||||
}
|
|
||||||
|
|
||||||
connections := make(map[string]PeerConnection)
|
|
||||||
connMgr := JwtConnectionManager{connections, serverConfig, clientConfig}
|
|
||||||
return &connMgr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *JwtConnectionManager) GetConnection(endpoint string) (PeerConnection, error) {
|
|
||||||
conn, exists := m.clientConnections[endpoint]
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
return nil, errors.New("endpoint: " + endpoint + " does not exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddToken: Adds a connection to the list of connections to manage
|
|
||||||
func (m *JwtConnectionManager) AddConnection(endPoint string) (PeerConnection, error) {
|
|
||||||
conn, exists := m.clientConnections[endPoint]
|
|
||||||
|
|
||||||
if exists {
|
|
||||||
return conn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
connections, err := NewWgCtrlConnection(m.clientConfig, endPoint)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
m.clientConnections[endPoint] = connections
|
|
||||||
return connections, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *JwtConnectionManager) HasConnection(endPoint string) bool {
|
|
||||||
_, exists := m.clientConnections[endPoint]
|
|
||||||
return exists
|
|
||||||
}
|
|
70
pkg/conn/connection.go
Normal file
70
pkg/conn/connection.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
// conn manages gRPC connections between peers.
|
||||||
|
// Includes timers.
|
||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PeerConnection represents a client-side connection between two
|
||||||
|
// peers.
|
||||||
|
type PeerConnection interface {
|
||||||
|
Close() error
|
||||||
|
GetClient() (*grpc.ClientConn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WgCtrlConnection implements PeerConnection.
|
||||||
|
type WgCtrlConnection struct {
|
||||||
|
clientConfig *tls.Config
|
||||||
|
conn *grpc.ClientConn
|
||||||
|
endpoint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWgCtrlConnection creates a new instance of a WireGuard control connection
|
||||||
|
func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnection, error) {
|
||||||
|
var conn WgCtrlConnection
|
||||||
|
conn.clientConfig = clientConfig
|
||||||
|
conn.endpoint = server
|
||||||
|
|
||||||
|
if err := conn.createGrpcConn(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectWithToken: Connects to a new gRPC peer given the address of the other server.
|
||||||
|
func (c *WgCtrlConnection) createGrpcConn() error {
|
||||||
|
conn, err := grpc.Dial(c.endpoint,
|
||||||
|
grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logging.Log.WriteErrorf("Could not connect: %s\n", err.Error())
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.conn = conn
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close: Closes the client connections
|
||||||
|
func (c *WgCtrlConnection) Close() error {
|
||||||
|
return c.conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClient: Gets the client connection
|
||||||
|
func (c *WgCtrlConnection) GetClient() (*grpc.ClientConn, error) {
|
||||||
|
var err error = nil
|
||||||
|
|
||||||
|
if c.conn == nil {
|
||||||
|
err = errors.New("The client's config does not exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.conn, err
|
||||||
|
}
|
129
pkg/conn/connectionmanager.go
Normal file
129
pkg/conn/connectionmanager.go
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
package conn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnectionManager defines an interface for maintaining peer connections
|
||||||
|
type ConnectionManager interface {
|
||||||
|
// AddConnection adds an instance of a connection at the given endpoint
|
||||||
|
// or error if something went wrong
|
||||||
|
AddConnection(endPoint string) (PeerConnection, error)
|
||||||
|
// GetConnection returns an instance of a connection at the given endpoint.
|
||||||
|
// If the endpoint does not exist then add the connection. Returns an error
|
||||||
|
// if something went wrong
|
||||||
|
GetConnection(endPoint string) (PeerConnection, error)
|
||||||
|
// HasConnections returns true if a client has already registered at the givne
|
||||||
|
// endpoint or false otherwise.
|
||||||
|
HasConnection(endPoint string) bool
|
||||||
|
// Goes through all the connections and closes eachone
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnectionManager manages connections between other peers
|
||||||
|
// in the control plane.
|
||||||
|
type ConnectionManagerImpl struct {
|
||||||
|
// clientConnections maps an endpoint to a connection
|
||||||
|
conLoc sync.RWMutex
|
||||||
|
clientConnections map[string]PeerConnection
|
||||||
|
serverConfig *tls.Config
|
||||||
|
clientConfig *tls.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new instance of a connection manager.
|
||||||
|
type NewConnectionManageParams struct {
|
||||||
|
// The path to the certificate
|
||||||
|
CertificatePath string
|
||||||
|
// The private key of the node
|
||||||
|
PrivateKey string
|
||||||
|
// Whether or not to skip certificate verification
|
||||||
|
SkipCertVerification bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConnectionManager: Creates a new instance of a ConnectionManager or an error
|
||||||
|
// if something went wrong.
|
||||||
|
func NewConnectionManager(params *NewConnectionManageParams) (ConnectionManager, error) {
|
||||||
|
cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logging.Log.WriteErrorf("Failed to load key pair: %s\n", err.Error())
|
||||||
|
logging.Log.WriteErrorf("Certificate Path: %s\n", params.CertificatePath)
|
||||||
|
logging.Log.WriteErrorf("Private Key Path: %s\n", params.PrivateKey)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAuth := tls.RequireAndVerifyClientCert
|
||||||
|
|
||||||
|
if params.SkipCertVerification {
|
||||||
|
serverAuth = tls.RequireAnyClientCert
|
||||||
|
}
|
||||||
|
|
||||||
|
serverConfig := &tls.Config{
|
||||||
|
ClientAuth: serverAuth,
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConfig := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
InsecureSkipVerify: params.SkipCertVerification,
|
||||||
|
}
|
||||||
|
|
||||||
|
connections := make(map[string]PeerConnection)
|
||||||
|
connMgr := ConnectionManagerImpl{sync.RWMutex{}, connections, serverConfig, clientConfig}
|
||||||
|
return &connMgr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConnection: Returns the given connection if it exists. If it does not exist then add
|
||||||
|
// the connection. Returns an error if something went wrong
|
||||||
|
func (m *ConnectionManagerImpl) GetConnection(endpoint string) (PeerConnection, error) {
|
||||||
|
m.conLoc.Lock()
|
||||||
|
conn, exists := m.clientConnections[endpoint]
|
||||||
|
m.conLoc.Unlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return m.AddConnection(endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddConnection: Adds a connection to the list of connections to manage.
|
||||||
|
func (m *ConnectionManagerImpl) AddConnection(endPoint string) (PeerConnection, error) {
|
||||||
|
m.conLoc.Lock()
|
||||||
|
conn, exists := m.clientConnections[endPoint]
|
||||||
|
m.conLoc.Unlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
connections, err := NewWgCtrlConnection(m.clientConfig, endPoint)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.conLoc.Lock()
|
||||||
|
m.clientConnections[endPoint] = connections
|
||||||
|
m.conLoc.Unlock()
|
||||||
|
return connections, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasConnection Returns TRUE if the given endpoint exists
|
||||||
|
func (m *ConnectionManagerImpl) HasConnection(endPoint string) bool {
|
||||||
|
_, exists := m.clientConnections[endPoint]
|
||||||
|
return exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ConnectionManagerImpl) Close() error {
|
||||||
|
for _, conn := range m.clientConnections {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
@ -3,9 +3,7 @@ package conn
|
|||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/auth"
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/conf"
|
"github.com/tim-beatham/wgmesh/pkg/conf"
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/rpc"
|
"github.com/tim-beatham/wgmesh/pkg/rpc"
|
||||||
@ -13,17 +11,23 @@ import (
|
|||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionServer manages the gRPC server
|
// ConnectionServer manages gRPC server peer connections
|
||||||
type ConnectionServer struct {
|
type ConnectionServer struct {
|
||||||
severConfig *tls.Config
|
// tlsConfiguration of the server
|
||||||
JwtManager *auth.JwtManager
|
serverConfig *tls.Config
|
||||||
server *grpc.Server
|
// server an instance of the grpc server
|
||||||
|
server *grpc.Server
|
||||||
|
// the authentication service to authenticate nodes
|
||||||
authProvider rpc.AuthenticationServer
|
authProvider rpc.AuthenticationServer
|
||||||
|
// the ctrl service to manage node
|
||||||
ctrlProvider rpc.MeshCtrlServerServer
|
ctrlProvider rpc.MeshCtrlServerServer
|
||||||
|
// the sync service to synchronise nodes
|
||||||
syncProvider rpc.SyncServiceServer
|
syncProvider rpc.SyncServiceServer
|
||||||
Conf *conf.WgMeshConfiguration
|
Conf *conf.WgMeshConfiguration
|
||||||
|
listener net.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewConnectionServerParams contains params for creating a new connection server
|
||||||
type NewConnectionServerParams struct {
|
type NewConnectionServerParams struct {
|
||||||
Conf *conf.WgMeshConfiguration
|
Conf *conf.WgMeshConfiguration
|
||||||
AuthProvider rpc.AuthenticationServer
|
AuthProvider rpc.AuthenticationServer
|
||||||
@ -36,9 +40,7 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
|
|||||||
cert, err := tls.LoadX509KeyPair(params.Conf.CertificatePath, params.Conf.PrivateKeyPath)
|
cert, err := tls.LoadX509KeyPair(params.Conf.CertificatePath, params.Conf.PrivateKeyPath)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.ErrorLog.Printf("Failed to load key pair: %s\n", err.Error())
|
logging.Log.WriteErrorf("Failed to load key pair: %s\n", err.Error())
|
||||||
logging.ErrorLog.Printf("Certificate Path: %s\n", params.Conf.CertificatePath)
|
|
||||||
logging.ErrorLog.Printf("Private Key Path: %s\n", params.Conf.PrivateKeyPath)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -53,10 +55,7 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
|
|||||||
Certificates: []tls.Certificate{cert},
|
Certificates: []tls.Certificate{cert},
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtManager := auth.NewJwtManager(params.Conf.Secret, 24*time.Hour)
|
|
||||||
|
|
||||||
server := grpc.NewServer(
|
server := grpc.NewServer(
|
||||||
grpc.UnaryInterceptor(jwtManager.GetAuthInterceptor()),
|
|
||||||
grpc.Creds(credentials.NewTLS(serverConfig)),
|
grpc.Creds(credentials.NewTLS(serverConfig)),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -65,38 +64,51 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer,
|
|||||||
syncProvider := params.SyncProvider
|
syncProvider := params.SyncProvider
|
||||||
|
|
||||||
connServer := ConnectionServer{
|
connServer := ConnectionServer{
|
||||||
serverConfig,
|
serverConfig: serverConfig,
|
||||||
jwtManager,
|
server: server,
|
||||||
server,
|
authProvider: authProvider,
|
||||||
authProvider,
|
ctrlProvider: ctrlProvider,
|
||||||
ctrlProvider,
|
syncProvider: syncProvider,
|
||||||
syncProvider,
|
Conf: params.Conf,
|
||||||
params.Conf,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &connServer, nil
|
return &connServer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Listen for incoming requests. Returns an error if something went wrong.
|
||||||
func (s *ConnectionServer) Listen() error {
|
func (s *ConnectionServer) Listen() error {
|
||||||
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
|
rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider)
|
||||||
rpc.RegisterAuthenticationServer(s.server, s.authProvider)
|
rpc.RegisterAuthenticationServer(s.server, s.authProvider)
|
||||||
|
|
||||||
logging.InfoLog.Println(s.syncProvider)
|
|
||||||
rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
|
rpc.RegisterSyncServiceServer(s.server, s.syncProvider)
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", ":"+s.Conf.GrpcPort)
|
lis, err := net.Listen("tcp", ":"+s.Conf.GrpcPort)
|
||||||
|
s.listener = lis
|
||||||
|
|
||||||
logging.InfoLog.Printf("GRPC listening on %s\n", s.Conf.GrpcPort)
|
logging.Log.WriteInfof("GRPC listening on %s\n", s.Conf.GrpcPort)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.ErrorLog.Println(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.server.Serve(lis); err != nil {
|
if err := s.server.Serve(lis); err != nil {
|
||||||
logging.ErrorLog.Println(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the connection server. Returns an error
|
||||||
|
// if something went wrong whilst attempting to close the connection
|
||||||
|
func (c *ConnectionServer) Close() error {
|
||||||
|
var err error = nil
|
||||||
|
c.server.Stop()
|
||||||
|
|
||||||
|
if c.listener != nil {
|
||||||
|
err = c.listener.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
@ -1,8 +1,3 @@
|
|||||||
/*
|
|
||||||
* ctrlserver controls the WireGuard mesh. Contains an IpcHandler for
|
|
||||||
* handling commands fired by wgmesh command.
|
|
||||||
* Contains an RpcHandler for handling commands fired by another server.
|
|
||||||
*/
|
|
||||||
package ctrlserver
|
package ctrlserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@ -13,6 +8,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NewCtrlServerParams are the params requried to create a new ctrl server
|
||||||
type NewCtrlServerParams struct {
|
type NewCtrlServerParams struct {
|
||||||
WgClient *wgctrl.Client
|
WgClient *wgctrl.Client
|
||||||
Conf *conf.WgMeshConfiguration
|
Conf *conf.WgMeshConfiguration
|
||||||
@ -21,32 +17,27 @@ type NewCtrlServerParams struct {
|
|||||||
SyncProvider rpc.SyncServiceServer
|
SyncProvider rpc.SyncServiceServer
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
// Create a new instance of the MeshCtrlServer or error if the
|
||||||
* NewCtrlServer creates a new instance of the ctrlserver.
|
// operation failed
|
||||||
* It is associated with a WireGuard client and an interface.
|
|
||||||
* wgClient: Represents the WireGuard control client.
|
|
||||||
* ifName: WireGuard interface name
|
|
||||||
*/
|
|
||||||
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
||||||
ctrlServer := new(MeshCtrlServer)
|
ctrlServer := new(MeshCtrlServer)
|
||||||
ctrlServer.Client = params.WgClient
|
ctrlServer.Client = params.WgClient
|
||||||
ctrlServer.MeshManager = mesh.NewMeshManager(*params.WgClient, *params.Conf)
|
ctrlServer.MeshManager = mesh.NewMeshManager(*params.WgClient, *params.Conf)
|
||||||
ctrlServer.Conf = params.Conf
|
ctrlServer.Conf = params.Conf
|
||||||
|
|
||||||
connManagerParams := conn.NewJwtConnectionManagerParams{
|
connManagerParams := conn.NewConnectionManageParams{
|
||||||
CertificatePath: params.Conf.CertificatePath,
|
CertificatePath: params.Conf.CertificatePath,
|
||||||
PrivateKey: params.Conf.PrivateKeyPath,
|
PrivateKey: params.Conf.PrivateKeyPath,
|
||||||
SkipCertVerification: params.Conf.SkipCertVerification,
|
SkipCertVerification: params.Conf.SkipCertVerification,
|
||||||
}
|
}
|
||||||
|
|
||||||
connMgr, err := conn.NewJwtConnectionManager(&connManagerParams)
|
connMgr, err := conn.NewConnectionManager(&connManagerParams)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctrlServer.ConnectionManager = connMgr
|
ctrlServer.ConnectionManager = connMgr
|
||||||
|
|
||||||
connServerParams := conn.NewConnectionServerParams{
|
connServerParams := conn.NewConnectionServerParams{
|
||||||
Conf: params.Conf,
|
Conf: params.Conf,
|
||||||
AuthProvider: params.AuthProvider,
|
AuthProvider: params.AuthProvider,
|
||||||
@ -63,3 +54,16 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) {
|
|||||||
ctrlServer.ConnectionServer = connServer
|
ctrlServer.ConnectionServer = connServer
|
||||||
return ctrlServer, nil
|
return ctrlServer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the ctrl server tearing down any connections that exist
|
||||||
|
func (s *MeshCtrlServer) Close() error {
|
||||||
|
if err := s.ConnectionManager.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.ConnectionServer.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -1,22 +1,51 @@
|
|||||||
|
// Provides a generic interface for logging
|
||||||
package logging
|
package logging
|
||||||
|
|
||||||
/*
|
|
||||||
* This package creates the info, warning and error loggers.
|
|
||||||
*/
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
InfoLog *log.Logger
|
Log Logger
|
||||||
WarningLog *log.Logger
|
|
||||||
ErrorLog *log.Logger
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
type Logger interface {
|
||||||
InfoLog = log.New(os.Stdout, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile)
|
WriteInfof(msg string, args ...interface{})
|
||||||
WarningLog = log.New(os.Stdout, "WARNING: ", log.Ldate|log.Ltime|log.Lshortfile)
|
WriteErrorf(msg string, args ...interface{})
|
||||||
ErrorLog = log.New(os.Stderr, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile)
|
WriteWarnf(msg string, args ...interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type LogrusLogger struct {
|
||||||
|
logger *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LogrusLogger) WriteInfof(msg string, args ...interface{}) {
|
||||||
|
l.logger.Infof(msg, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LogrusLogger) WriteErrorf(msg string, args ...interface{}) {
|
||||||
|
l.logger.Errorf(msg, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LogrusLogger) WriteWarnf(msg string, args ...interface{}) {
|
||||||
|
l.logger.Warnf(msg, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLogrusLogger() *LogrusLogger {
|
||||||
|
logger := logrus.New()
|
||||||
|
logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true})
|
||||||
|
logger.SetOutput(os.Stdout)
|
||||||
|
logger.SetLevel(logrus.InfoLevel)
|
||||||
|
|
||||||
|
return &LogrusLogger{logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
SetLogger(NewLogrusLogger())
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetLogger(l Logger) {
|
||||||
|
Log = l
|
||||||
}
|
}
|
||||||
|
@ -4,16 +4,17 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
|
||||||
"github.com/tim-beatham/wgmesh/pkg/auth"
|
|
||||||
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
logging "github.com/tim-beatham/wgmesh/pkg/log"
|
||||||
"github.com/tim-beatham/wgmesh/pkg/rpc"
|
"github.com/tim-beatham/wgmesh/pkg/rpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// AuthRpcProvider implements the AuthRpcProvider service
|
||||||
type AuthRpcProvider struct {
|
type AuthRpcProvider struct {
|
||||||
rpc.UnimplementedAuthenticationServer
|
rpc.UnimplementedAuthenticationServer
|
||||||
Manager *auth.JwtManager
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JoinMesh handles a JoinMeshRequest. Succeeds by stating the node managed to join the mesh
|
||||||
|
// or returns an error if it failed
|
||||||
func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequest) (*rpc.JoinAuthMeshReply, error) {
|
func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequest) (*rpc.JoinAuthMeshReply, error) {
|
||||||
meshId := in.MeshId
|
meshId := in.MeshId
|
||||||
|
|
||||||
@ -21,12 +22,8 @@ func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequ
|
|||||||
return nil, errors.New("Must specify the meshId")
|
return nil, errors.New("Must specify the meshId")
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.InfoLog.Println("MeshID: " + in.MeshId)
|
logging.Log.WriteInfof("MeshID: " + in.MeshId)
|
||||||
token, err := a.Manager.CreateClaims(in.MeshId, in.Alias)
|
|
||||||
|
|
||||||
if err != nil {
|
var token string = ""
|
||||||
return nil, err
|
return &rpc.JoinAuthMeshReply{Success: true, Token: &token}, nil
|
||||||
}
|
|
||||||
|
|
||||||
return &rpc.JoinAuthMeshReply{Success: true, Token: token}, nil
|
|
||||||
}
|
}
|
||||||
|
@ -75,71 +75,9 @@ func (n *RobinIpc) ListMeshes(_ string, reply *ipc.ListMeshReply) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n *RobinIpc) Authenticate(meshId, endpoint string) error {
|
|
||||||
peerConnection, err := n.Server.ConnectionManager.AddConnection(endpoint)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = peerConnection.Authenticate(meshId)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *RobinIpc) authenticatePeers(meshId string) error {
|
|
||||||
theMesh := n.Server.MeshManager.GetMesh(meshId)
|
|
||||||
|
|
||||||
if theMesh == nil {
|
|
||||||
return errors.New("the mesh does not exist")
|
|
||||||
}
|
|
||||||
|
|
||||||
snapshot, _ := theMesh.GetCrdt()
|
|
||||||
publicKey, err := n.Server.MeshManager.GetPublicKey(meshId)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for nodeKey, node := range snapshot.Nodes {
|
|
||||||
logging.InfoLog.Println(nodeKey)
|
|
||||||
if nodeKey == publicKey.String() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err := n.Authenticate(meshId, node.HostEndpoint)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
|
func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
|
||||||
err := n.Authenticate(args.MeshId, args.IpAdress)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peerConnection, err := n.Server.ConnectionManager.GetConnection(args.IpAdress)
|
peerConnection, err := n.Server.ConnectionManager.GetConnection(args.IpAdress)
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = peerConnection.Connect()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := peerConnection.GetClient()
|
client, err := peerConnection.GetClient()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -148,13 +86,11 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
|
|||||||
|
|
||||||
c := rpc.NewMeshCtrlServerClient(client)
|
c := rpc.NewMeshCtrlServerClient(client)
|
||||||
|
|
||||||
authContext, err := peerConnection.CreateAuthContext(args.MeshId)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(authContext, time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
|
meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId})
|
||||||
@ -181,7 +117,7 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.InfoLog.Println("WgIP: " + ipAddr.String())
|
logging.Log.WriteInfof("WgIP: " + ipAddr.String())
|
||||||
|
|
||||||
outBoundIP := lib.GetOutboundIP()
|
outBoundIP := lib.GetOutboundIP()
|
||||||
|
|
||||||
@ -206,10 +142,6 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if joinReply.GetSuccess() {
|
|
||||||
err = n.authenticatePeers(args.MeshId)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -56,7 +56,7 @@ func (m *RobinRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*r
|
|||||||
func (m *RobinRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
|
func (m *RobinRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) {
|
||||||
mesh := m.Server.MeshManager.GetMesh(request.MeshId)
|
mesh := m.Server.MeshManager.GetMesh(request.MeshId)
|
||||||
|
|
||||||
logging.InfoLog.Println("[JOINING MESH]: " + request.MeshId)
|
logging.Log.WriteInfof("[JOINING MESH]: " + request.MeshId)
|
||||||
|
|
||||||
if mesh == nil {
|
if mesh == nil {
|
||||||
return nil, errors.New("mesh does not exist")
|
return nil, errors.New("mesh does not exist")
|
||||||
|
@ -34,7 +34,7 @@ func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint stri
|
|||||||
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
|
func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool {
|
||||||
errStatus, _ := status.FromError(err)
|
errStatus, _ := status.FromError(err)
|
||||||
|
|
||||||
logging.WarningLog.Printf("Handled gRPC error: %s", errStatus.Message())
|
logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message())
|
||||||
|
|
||||||
switch errStatus.Code() {
|
switch errStatus.Code() {
|
||||||
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
|
case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound:
|
||||||
|
@ -23,22 +23,6 @@ type SyncRequesterImpl struct {
|
|||||||
errorHdlr SyncErrorHandler
|
errorHdlr SyncErrorHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SyncRequesterImpl) Authenticate(meshId, endpoint string) error {
|
|
||||||
peerConnection, err := s.server.ConnectionManager.AddConnection(endpoint)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = peerConnection.Authenticate(meshId)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetMesh: Retrieves the local state of the mesh at the endpoint
|
// GetMesh: Retrieves the local state of the mesh at the endpoint
|
||||||
func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error {
|
func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error {
|
||||||
peerConnection, err := s.server.ConnectionManager.GetConnection(endPoint)
|
peerConnection, err := s.server.ConnectionManager.GetConnection(endPoint)
|
||||||
@ -47,12 +31,6 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = peerConnection.Connect()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := peerConnection.GetClient()
|
client, err := peerConnection.GetClient()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -60,13 +38,8 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c := rpc.NewSyncServiceClient(client)
|
c := rpc.NewSyncServiceClient(client)
|
||||||
authContext, err := peerConnection.CreateAuthContext(meshId)
|
|
||||||
|
|
||||||
if err != nil {
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(authContext, time.Second)
|
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
reply, err := c.GetConf(ctx, &rpc.GetConfRequest{MeshId: meshId})
|
reply, err := c.GetConf(ctx, &rpc.GetConfRequest{MeshId: meshId})
|
||||||
@ -91,34 +64,18 @@ func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error
|
|||||||
|
|
||||||
// SyncMesh: Proactively send a sync request to the other mesh
|
// SyncMesh: Proactively send a sync request to the other mesh
|
||||||
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
||||||
if !s.server.ConnectionManager.HasConnection(endpoint) {
|
|
||||||
s.Authenticate(meshId, endpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
|
peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = peerConnection.Connect()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return s.handleErr(meshId, endpoint, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := peerConnection.GetClient()
|
client, err := peerConnection.GetClient()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
authContext, err := peerConnection.CreateAuthContext(meshId)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
mesh := s.server.MeshManager.GetMesh(meshId)
|
mesh := s.server.MeshManager.GetMesh(meshId)
|
||||||
|
|
||||||
if mesh == nil {
|
if mesh == nil {
|
||||||
@ -127,7 +84,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
|||||||
|
|
||||||
c := rpc.NewSyncServiceClient(client)
|
c := rpc.NewSyncServiceClient(client)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(authContext, 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
err = syncMesh(mesh, ctx, c)
|
err = syncMesh(mesh, ctx, c)
|
||||||
@ -136,7 +93,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error {
|
|||||||
return s.handleErr(meshId, endpoint, err)
|
return s.handleErr(meshId, endpoint, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.InfoLog.Printf("Synced with node: %s meshId: %s\n", endpoint, meshId)
|
logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId)
|
||||||
mesh.DecrementFailedCount(endpoint)
|
mesh.DecrementFailedCount(endpoint)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -162,7 +119,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe
|
|||||||
in, err := stream.Recv()
|
in, err := stream.Recv()
|
||||||
|
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
logging.ErrorLog.Printf("Stream recv error: %s\n", err.Error())
|
logging.Log.WriteInfof("Stream recv error: %s\n", err.Error())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -171,7 +128,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.ErrorLog.Printf("Syncer recv error: %s\n", err.Error())
|
logging.Log.WriteInfof("Syncer recv error: %s\n", err.Error())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -180,7 +137,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.InfoLog.Println("SYNC finished")
|
logging.Log.WriteInfof("SYNC finished")
|
||||||
stream.CloseSend()
|
stream.CloseSend()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,7 @@ func (s *SyncSchedulerImpl) Run() error {
|
|||||||
err := s.syncer.SyncMeshes()
|
err := s.syncer.SyncMeshes()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logging.ErrorLog.Println(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
case <-quit:
|
case <-quit:
|
||||||
|
@ -41,9 +41,9 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
|
|||||||
var syncer *crdt.AutomergeSync = nil
|
var syncer *crdt.AutomergeSync = nil
|
||||||
|
|
||||||
for {
|
for {
|
||||||
logging.InfoLog.Println("Received Attempt")
|
logging.Log.WriteInfof("Received Attempt")
|
||||||
in, err := stream.Recv()
|
in, err := stream.Recv()
|
||||||
logging.InfoLog.Println("Received Worked")
|
logging.Log.WriteInfof("Received Worked")
|
||||||
|
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
return nil
|
return nil
|
||||||
@ -84,7 +84,6 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !moreMessages || err == io.EOF {
|
if !moreMessages || err == io.EOF {
|
||||||
logging.InfoLog.Println("SYNC Completed")
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -62,7 +62,7 @@ func EnableInterface(ifName string, ip string) error {
|
|||||||
cmd := exec.Command("/usr/bin/ip", "link", "set", "up", "dev", ifName)
|
cmd := exec.Command("/usr/bin/ip", "link", "set", "up", "dev", ifName)
|
||||||
|
|
||||||
if err := cmd.Run(); err != nil {
|
if err := cmd.Run(); err != nil {
|
||||||
logging.ErrorLog.Println(err.Error())
|
logging.Log.WriteErrorf(err.Error())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user