From 8e89281484bd5fcb40fb0b262a52fc1116f3291f Mon Sep 17 00:00:00 2001 From: Tim Beatham Date: Tue, 24 Oct 2023 00:12:38 +0100 Subject: [PATCH] Tested with large number of nodes --- cmd/wg-mesh/main.go | 2 +- cmd/wgmeshd/main.go | 11 +- pkg/auth/jwt.go | 140 ------------------ pkg/automerge/automerge.go | 10 +- pkg/conf/conf.go | 4 +- pkg/conn/conn.go | 116 --------------- pkg/conn/conn_manager.go | 93 ------------ pkg/conn/connection.go | 70 +++++++++ pkg/conn/connectionmanager.go | 129 ++++++++++++++++ .../{conn_server.go => connectionserver.go} | 58 +++++--- pkg/ctrlserver/ctrlserver.go | 32 ++-- pkg/log/log.go | 53 +++++-- pkg/middleware/auth.go | 15 +- pkg/robin/robin_requester.go | 72 +-------- pkg/robin/robin_responder.go | 2 +- pkg/sync/syncererror.go | 2 +- pkg/sync/syncrequester.go | 55 +------ pkg/sync/syncscheduler.go | 2 +- pkg/sync/syncservice.go | 5 +- pkg/wg/wg.go | 2 +- 20 files changed, 326 insertions(+), 547 deletions(-) delete mode 100644 pkg/auth/jwt.go delete mode 100644 pkg/conn/conn.go delete mode 100644 pkg/conn/conn_manager.go create mode 100644 pkg/conn/connection.go create mode 100644 pkg/conn/connectionmanager.go rename pkg/conn/{conn_server.go => connectionserver.go} (60%) diff --git a/cmd/wg-mesh/main.go b/cmd/wg-mesh/main.go index c0d1220..5f66268 100644 --- a/cmd/wg-mesh/main.go +++ b/cmd/wg-mesh/main.go @@ -31,7 +31,7 @@ func listMeshes(client *ipcRpc.Client) { err := client.Call("RobinIpc.ListMeshes", "", &reply) if err != nil { - logging.ErrorLog.Println(err.Error()) + logging.Log.WriteErrorf(err.Error()) return } diff --git a/cmd/wgmeshd/main.go b/cmd/wgmeshd/main.go index 82defbd..3e9ef25 100644 --- a/cmd/wgmeshd/main.go +++ b/cmd/wgmeshd/main.go @@ -36,7 +36,6 @@ func main() { } ctrlServer, err := ctrlserver.NewCtrlServer(&ctrlServerParams) - authProvider.Manager = ctrlServer.ConnectionServer.JwtManager syncProvider.Server = ctrlServer syncRequester := sync.NewSyncRequester(ctrlServer) syncScheduler := sync.NewSyncScheduler(ctrlServer, syncRequester, 2) @@ -50,7 +49,8 @@ func main() { robinIpc = robin.NewRobinIpc(robinIpcParams) if err != nil { - logging.ErrorLog.Fatalln(err.Error()) + logging.Log.WriteErrorf(err.Error()) + return } log.Println("Running IPC Handler") @@ -61,9 +61,12 @@ func main() { err = ctrlServer.ConnectionServer.Listen() if err != nil { - logging.ErrorLog.Fatalln(err.Error()) + logging.Log.WriteErrorf(err.Error()) + + return } - defer wgClient.Close() defer syncScheduler.Stop() + defer ctrlServer.Close() + defer wgClient.Close() } diff --git a/pkg/auth/jwt.go b/pkg/auth/jwt.go deleted file mode 100644 index 4947847..0000000 --- a/pkg/auth/jwt.go +++ /dev/null @@ -1,140 +0,0 @@ -package auth - -import ( - "context" - "errors" - "fmt" - "strconv" - "strings" - "time" - - "github.com/golang-jwt/jwt/v5" - logging "github.com/tim-beatham/wgmesh/pkg/log" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" -) - -// JwtMesh contains all the sessions with the mesh network -type JwtMesh struct { - meshId string - // nodes contains a set of nodes with the string being the jwt token - nodes map[string]interface{} -} - -// JwtManager manages jwt tokens indicating a session -// between this host and another within a specific mesh -type JwtManager struct { - secretKey []byte - tokenDuration time.Duration - // meshes contains all the meshes that we have sessions with - meshes map[string]*JwtMesh -} - -// JwtNode represents a jwt node in the mesh network -type JwtNode struct { - MeshId string `json:"meshId"` - Alias string `json:"alias"` - jwt.RegisteredClaims -} - -func NewJwtManager(secretKey string, tokenDuration time.Duration) *JwtManager { - meshes := make(map[string]*JwtMesh) - return &JwtManager{[]byte(secretKey), tokenDuration, meshes} -} - -func (m *JwtManager) CreateClaims(meshId string, alias string) (*string, error) { - logging.InfoLog.Println("MeshID: " + meshId) - logging.InfoLog.Println("Token Duration: " + strconv.Itoa(int(m.tokenDuration))) - node := JwtNode{ - MeshId: meshId, - Alias: alias, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(time.Now().Add(m.tokenDuration)), - }, - } - - mesh, contains := m.meshes[meshId] - - if !contains { - mesh = new(JwtMesh) - mesh.meshId = meshId - mesh.nodes = make(map[string]interface{}) - mesh.nodes[meshId] = mesh - } - - token := jwt.NewWithClaims(jwt.SigningMethodHS256, node) - signedString, err := token.SignedString(m.secretKey) - - if err != nil { - fmt.Println(err.Error()) - return nil, err - } - - _, exists := mesh.nodes[signedString] - - if exists { - return nil, errors.New("Node already exists") - } - - mesh.nodes[signedString] = struct{}{} - return &signedString, nil -} - -func (m *JwtManager) Verify(accessToken string) (*JwtNode, bool) { - token, err := jwt.ParseWithClaims(accessToken, &JwtNode{}, func(t *jwt.Token) (interface{}, error) { - return m.secretKey, nil - }) - - if err != nil { - return nil, false - } - - if !token.Valid { - return nil, token.Valid - } - - claims, ok := token.Claims.(*JwtNode) - return claims, ok -} - -func (m *JwtManager) GetAuthInterceptor() grpc.UnaryServerInterceptor { - return func( - ctx context.Context, - req interface{}, - info *grpc.UnaryServerInfo, - handler grpc.UnaryHandler, - ) (interface{}, error) { - - if strings.Contains(info.FullMethod, "") { - return handler(ctx, req) - } - - md, ok := metadata.FromIncomingContext(ctx) - - if !ok { - return nil, status.Errorf(codes.Unauthenticated, "metadata is not provided") - } - - values := md["authorization"] - - for _, w := range values { - logging.InfoLog.Printf(w) - } - - if len(values) == 0 { - return nil, status.Errorf(codes.Unauthenticated, "authorization token is not provided") - } - - acessToken := values[0] - - _, valid := m.Verify(acessToken) - - if !valid { - return nil, status.Errorf(codes.Unauthenticated, "Invalid access token: %s", acessToken) - } - - return handler(ctx, req) - } -} diff --git a/pkg/automerge/automerge.go b/pkg/automerge/automerge.go index 2816702..0f1176d 100644 --- a/pkg/automerge/automerge.go +++ b/pkg/automerge/automerge.go @@ -193,14 +193,13 @@ func (m *CrdtNodeManager) Length() int { return m.doc.Path("nodes").Map().Len() } -const threshold = 2 const thresholdVotes = 0.1 func (m *CrdtNodeManager) HasFailed(endpoint string) bool { node, err := m.GetNode(endpoint) if err != nil { - logging.InfoLog.Printf("Cannot get node node: %s\n", endpoint) + logging.Log.WriteErrorf("Cannot get node node: %s\n", endpoint) return true } @@ -215,14 +214,12 @@ func (m *CrdtNodeManager) HasFailed(endpoint string) bool { for _, value := range values { count := value.Int64() - if count >= threshold { + if count >= 1 { countFailed++ } } - logging.InfoLog.Printf("Count Failed Value: %d\n", countFailed) - logging.InfoLog.Printf("Threshold Value: %d\n", int(thresholdVotes*float64(m.Length())+1)) - return countFailed >= int(thresholdVotes*float64(m.Length())+1) + return countFailed >= 4 } func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNodeCrdt, client wgctrl.Client) error { @@ -232,7 +229,6 @@ func (m *CrdtNodeManager) updateWgConf(devName string, nodes map[string]MeshNode for _, n := range nodes { peer, err := m.convertMeshNode(n) - logging.InfoLog.Println(n.HostEndpoint) if err != nil { return err diff --git a/pkg/conf/conf.go b/pkg/conf/conf.go index 8fd93f6..4a34ca0 100644 --- a/pkg/conf/conf.go +++ b/pkg/conf/conf.go @@ -24,14 +24,14 @@ func ParseConfiguration(filePath string) (*WgMeshConfiguration, error) { yamlBytes, err := os.ReadFile(filePath) if err != nil { - logging.ErrorLog.Printf("Read file error: %s\n", err.Error()) + logging.Log.WriteErrorf("Read file error: %s\n", err.Error()) return nil, err } err = yaml.Unmarshal(yamlBytes, &conf) if err != nil { - logging.ErrorLog.Printf("Unmarshal error: %s\n", err.Error()) + logging.Log.WriteErrorf("Unmarshal error: %s\n", err.Error()) return nil, err } diff --git a/pkg/conn/conn.go b/pkg/conn/conn.go deleted file mode 100644 index b884055..0000000 --- a/pkg/conn/conn.go +++ /dev/null @@ -1,116 +0,0 @@ -// conn manages gRPC connections between peers. -// Includes timers. -package conn - -import ( - "context" - "crypto/tls" - "errors" - "time" - - "github.com/tim-beatham/wgmesh/pkg/lib" - logging "github.com/tim-beatham/wgmesh/pkg/log" - "github.com/tim-beatham/wgmesh/pkg/rpc" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/metadata" -) - -// PeerConnection interfacing for a secure connection between -// two peers. -type PeerConnection interface { - Connect() error - Close() error - Authenticate(meshId string) error - GetClient() (*grpc.ClientConn, error) - CreateAuthContext(meshId string) (context.Context, error) -} - -type WgCtrlConnection struct { - clientConfig *tls.Config - conn *grpc.ClientConn - endpoint string - // tokens maps a meshID to the corresponding token - tokens map[string]string -} - -func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnection, error) { - var conn WgCtrlConnection - conn.tokens = make(map[string]string) - conn.clientConfig = clientConfig - conn.endpoint = server - return &conn, nil -} - -func (c *WgCtrlConnection) Authenticate(meshId string) error { - conn, err := grpc.Dial(c.endpoint, - grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig))) - - defer conn.Close() - - if err != nil { - return err - } - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - - client := rpc.NewAuthenticationClient(conn) - defer cancel() - - authRequest := rpc.JoinAuthMeshRequest{ - MeshId: meshId, - Alias: lib.GetOutboundIP().String(), - } - - reply, err := client.JoinMesh(ctx, &authRequest) - - if err != nil { - return err - } - - c.tokens[meshId] = *reply.Token - return nil -} - -// ConnectWithToken: Connects to a new gRPC peer given the address of the other server. -func (c *WgCtrlConnection) Connect() error { - conn, err := grpc.Dial(c.endpoint, - grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)), - ) - - if err != nil { - logging.ErrorLog.Printf("Could not connect: %s\n", err.Error()) - return err - } - - c.conn = conn - return nil -} - -// Close: Closes the client connections -func (c *WgCtrlConnection) Close() error { - return c.conn.Close() -} - -// GetClient: Gets the client connection -func (c *WgCtrlConnection) GetClient() (*grpc.ClientConn, error) { - var err error = nil - - if c.conn == nil { - err = errors.New("The client's config does not exist") - } - - return c.conn, err -} - -// TODO: Implement a mechanism to attach a security token -func (c *WgCtrlConnection) CreateAuthContext(meshId string) (context.Context, error) { - token, ok := c.tokens[meshId] - - if !ok { - return nil, errors.New("MeshID: " + meshId + " does not exist") - } - - ctx := context.Background() - return metadata.AppendToOutgoingContext(ctx, "authorization", token), nil -} diff --git a/pkg/conn/conn_manager.go b/pkg/conn/conn_manager.go deleted file mode 100644 index 8045180..0000000 --- a/pkg/conn/conn_manager.go +++ /dev/null @@ -1,93 +0,0 @@ -package conn - -import ( - "crypto/tls" - "errors" - - logging "github.com/tim-beatham/wgmesh/pkg/log" -) - -type ConnectionManager interface { - AddConnection(endPoint string) (PeerConnection, error) - GetConnection(endPoint string) (PeerConnection, error) - HasConnection(endPoint string) bool -} - -// ConnectionManager manages connections between other peers -// in the control plane. -type JwtConnectionManager struct { - // clientConnections maps an endpoint to a connection - clientConnections map[string]PeerConnection - serverConfig *tls.Config - clientConfig *tls.Config -} - -type NewJwtConnectionManagerParams struct { - CertificatePath string - PrivateKey string - SkipCertVerification bool -} - -func NewJwtConnectionManager(params *NewJwtConnectionManagerParams) (ConnectionManager, error) { - cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey) - - if err != nil { - logging.ErrorLog.Printf("Failed to load key pair: %s\n", err.Error()) - logging.ErrorLog.Printf("Certificate Path: %s\n", params.CertificatePath) - logging.ErrorLog.Printf("Private Key Path: %s\n", params.PrivateKey) - return nil, err - } - - serverAuth := tls.RequireAndVerifyClientCert - - if params.SkipCertVerification { - serverAuth = tls.RequireAnyClientCert - } - - serverConfig := &tls.Config{ - ClientAuth: serverAuth, - Certificates: []tls.Certificate{cert}, - } - - clientConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - InsecureSkipVerify: params.SkipCertVerification, - } - - connections := make(map[string]PeerConnection) - connMgr := JwtConnectionManager{connections, serverConfig, clientConfig} - return &connMgr, nil -} - -func (m *JwtConnectionManager) GetConnection(endpoint string) (PeerConnection, error) { - conn, exists := m.clientConnections[endpoint] - - if !exists { - return nil, errors.New("endpoint: " + endpoint + " does not exist") - } - - return conn, nil -} - -// AddToken: Adds a connection to the list of connections to manage -func (m *JwtConnectionManager) AddConnection(endPoint string) (PeerConnection, error) { - conn, exists := m.clientConnections[endPoint] - - if exists { - return conn, nil - } - - connections, err := NewWgCtrlConnection(m.clientConfig, endPoint) - - if err != nil { - return nil, err - } - - m.clientConnections[endPoint] = connections - return connections, nil -} - -func (m *JwtConnectionManager) HasConnection(endPoint string) bool { - _, exists := m.clientConnections[endPoint] - return exists -} diff --git a/pkg/conn/connection.go b/pkg/conn/connection.go new file mode 100644 index 0000000..2b84a78 --- /dev/null +++ b/pkg/conn/connection.go @@ -0,0 +1,70 @@ +// conn manages gRPC connections between peers. +// Includes timers. +package conn + +import ( + "crypto/tls" + "errors" + + logging "github.com/tim-beatham/wgmesh/pkg/log" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +// PeerConnection represents a client-side connection between two +// peers. +type PeerConnection interface { + Close() error + GetClient() (*grpc.ClientConn, error) +} + +// WgCtrlConnection implements PeerConnection. +type WgCtrlConnection struct { + clientConfig *tls.Config + conn *grpc.ClientConn + endpoint string +} + +// NewWgCtrlConnection creates a new instance of a WireGuard control connection +func NewWgCtrlConnection(clientConfig *tls.Config, server string) (*WgCtrlConnection, error) { + var conn WgCtrlConnection + conn.clientConfig = clientConfig + conn.endpoint = server + + if err := conn.createGrpcConn(); err != nil { + return nil, err + } + + return &conn, nil +} + +// ConnectWithToken: Connects to a new gRPC peer given the address of the other server. +func (c *WgCtrlConnection) createGrpcConn() error { + conn, err := grpc.Dial(c.endpoint, + grpc.WithTransportCredentials(credentials.NewTLS(c.clientConfig)), + ) + + if err != nil { + logging.Log.WriteErrorf("Could not connect: %s\n", err.Error()) + return err + } + + c.conn = conn + return nil +} + +// Close: Closes the client connections +func (c *WgCtrlConnection) Close() error { + return c.conn.Close() +} + +// GetClient: Gets the client connection +func (c *WgCtrlConnection) GetClient() (*grpc.ClientConn, error) { + var err error = nil + + if c.conn == nil { + err = errors.New("The client's config does not exist") + } + + return c.conn, err +} diff --git a/pkg/conn/connectionmanager.go b/pkg/conn/connectionmanager.go new file mode 100644 index 0000000..10c8030 --- /dev/null +++ b/pkg/conn/connectionmanager.go @@ -0,0 +1,129 @@ +package conn + +import ( + "crypto/tls" + "sync" + + logging "github.com/tim-beatham/wgmesh/pkg/log" +) + +// ConnectionManager defines an interface for maintaining peer connections +type ConnectionManager interface { + // AddConnection adds an instance of a connection at the given endpoint + // or error if something went wrong + AddConnection(endPoint string) (PeerConnection, error) + // GetConnection returns an instance of a connection at the given endpoint. + // If the endpoint does not exist then add the connection. Returns an error + // if something went wrong + GetConnection(endPoint string) (PeerConnection, error) + // HasConnections returns true if a client has already registered at the givne + // endpoint or false otherwise. + HasConnection(endPoint string) bool + // Goes through all the connections and closes eachone + Close() error +} + +// ConnectionManager manages connections between other peers +// in the control plane. +type ConnectionManagerImpl struct { + // clientConnections maps an endpoint to a connection + conLoc sync.RWMutex + clientConnections map[string]PeerConnection + serverConfig *tls.Config + clientConfig *tls.Config +} + +// Create a new instance of a connection manager. +type NewConnectionManageParams struct { + // The path to the certificate + CertificatePath string + // The private key of the node + PrivateKey string + // Whether or not to skip certificate verification + SkipCertVerification bool +} + +// NewConnectionManager: Creates a new instance of a ConnectionManager or an error +// if something went wrong. +func NewConnectionManager(params *NewConnectionManageParams) (ConnectionManager, error) { + cert, err := tls.LoadX509KeyPair(params.CertificatePath, params.PrivateKey) + + if err != nil { + logging.Log.WriteErrorf("Failed to load key pair: %s\n", err.Error()) + logging.Log.WriteErrorf("Certificate Path: %s\n", params.CertificatePath) + logging.Log.WriteErrorf("Private Key Path: %s\n", params.PrivateKey) + return nil, err + } + + serverAuth := tls.RequireAndVerifyClientCert + + if params.SkipCertVerification { + serverAuth = tls.RequireAnyClientCert + } + + serverConfig := &tls.Config{ + ClientAuth: serverAuth, + Certificates: []tls.Certificate{cert}, + } + + clientConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + InsecureSkipVerify: params.SkipCertVerification, + } + + connections := make(map[string]PeerConnection) + connMgr := ConnectionManagerImpl{sync.RWMutex{}, connections, serverConfig, clientConfig} + return &connMgr, nil +} + +// GetConnection: Returns the given connection if it exists. If it does not exist then add +// the connection. Returns an error if something went wrong +func (m *ConnectionManagerImpl) GetConnection(endpoint string) (PeerConnection, error) { + m.conLoc.Lock() + conn, exists := m.clientConnections[endpoint] + m.conLoc.Unlock() + + if !exists { + return m.AddConnection(endpoint) + } + + return conn, nil +} + +// AddConnection: Adds a connection to the list of connections to manage. +func (m *ConnectionManagerImpl) AddConnection(endPoint string) (PeerConnection, error) { + m.conLoc.Lock() + conn, exists := m.clientConnections[endPoint] + m.conLoc.Unlock() + + if exists { + return conn, nil + } + + connections, err := NewWgCtrlConnection(m.clientConfig, endPoint) + + if err != nil { + return nil, err + } + + m.conLoc.Lock() + m.clientConnections[endPoint] = connections + m.conLoc.Unlock() + return connections, nil +} + +// HasConnection Returns TRUE if the given endpoint exists +func (m *ConnectionManagerImpl) HasConnection(endPoint string) bool { + _, exists := m.clientConnections[endPoint] + return exists +} + +func (m *ConnectionManagerImpl) Close() error { + for _, conn := range m.clientConnections { + if err := conn.Close(); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/conn/conn_server.go b/pkg/conn/connectionserver.go similarity index 60% rename from pkg/conn/conn_server.go rename to pkg/conn/connectionserver.go index 3c32dd2..44081ad 100644 --- a/pkg/conn/conn_server.go +++ b/pkg/conn/connectionserver.go @@ -3,9 +3,7 @@ package conn import ( "crypto/tls" "net" - "time" - "github.com/tim-beatham/wgmesh/pkg/auth" "github.com/tim-beatham/wgmesh/pkg/conf" logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/wgmesh/pkg/rpc" @@ -13,17 +11,23 @@ import ( "google.golang.org/grpc/credentials" ) -// ConnectionServer manages the gRPC server +// ConnectionServer manages gRPC server peer connections type ConnectionServer struct { - severConfig *tls.Config - JwtManager *auth.JwtManager - server *grpc.Server + // tlsConfiguration of the server + serverConfig *tls.Config + // server an instance of the grpc server + server *grpc.Server + // the authentication service to authenticate nodes authProvider rpc.AuthenticationServer + // the ctrl service to manage node ctrlProvider rpc.MeshCtrlServerServer + // the sync service to synchronise nodes syncProvider rpc.SyncServiceServer Conf *conf.WgMeshConfiguration + listener net.Listener } +// NewConnectionServerParams contains params for creating a new connection server type NewConnectionServerParams struct { Conf *conf.WgMeshConfiguration AuthProvider rpc.AuthenticationServer @@ -36,9 +40,7 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer, cert, err := tls.LoadX509KeyPair(params.Conf.CertificatePath, params.Conf.PrivateKeyPath) if err != nil { - logging.ErrorLog.Printf("Failed to load key pair: %s\n", err.Error()) - logging.ErrorLog.Printf("Certificate Path: %s\n", params.Conf.CertificatePath) - logging.ErrorLog.Printf("Private Key Path: %s\n", params.Conf.PrivateKeyPath) + logging.Log.WriteErrorf("Failed to load key pair: %s\n", err.Error()) return nil, err } @@ -53,10 +55,7 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer, Certificates: []tls.Certificate{cert}, } - jwtManager := auth.NewJwtManager(params.Conf.Secret, 24*time.Hour) - server := grpc.NewServer( - grpc.UnaryInterceptor(jwtManager.GetAuthInterceptor()), grpc.Creds(credentials.NewTLS(serverConfig)), ) @@ -65,38 +64,51 @@ func NewConnectionServer(params *NewConnectionServerParams) (*ConnectionServer, syncProvider := params.SyncProvider connServer := ConnectionServer{ - serverConfig, - jwtManager, - server, - authProvider, - ctrlProvider, - syncProvider, - params.Conf, + serverConfig: serverConfig, + server: server, + authProvider: authProvider, + ctrlProvider: ctrlProvider, + syncProvider: syncProvider, + Conf: params.Conf, } return &connServer, nil } +// Listen for incoming requests. Returns an error if something went wrong. func (s *ConnectionServer) Listen() error { rpc.RegisterMeshCtrlServerServer(s.server, s.ctrlProvider) rpc.RegisterAuthenticationServer(s.server, s.authProvider) - logging.InfoLog.Println(s.syncProvider) rpc.RegisterSyncServiceServer(s.server, s.syncProvider) lis, err := net.Listen("tcp", ":"+s.Conf.GrpcPort) + s.listener = lis - logging.InfoLog.Printf("GRPC listening on %s\n", s.Conf.GrpcPort) + logging.Log.WriteInfof("GRPC listening on %s\n", s.Conf.GrpcPort) if err != nil { - logging.ErrorLog.Println(err.Error()) + logging.Log.WriteErrorf(err.Error()) return err } if err := s.server.Serve(lis); err != nil { - logging.ErrorLog.Println(err.Error()) + logging.Log.WriteErrorf(err.Error()) return err } return nil } + +// Close closes the connection server. Returns an error +// if something went wrong whilst attempting to close the connection +func (c *ConnectionServer) Close() error { + var err error = nil + c.server.Stop() + + if c.listener != nil { + err = c.listener.Close() + } + + return err +} diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go index 0afa349..3bcfd9b 100644 --- a/pkg/ctrlserver/ctrlserver.go +++ b/pkg/ctrlserver/ctrlserver.go @@ -1,8 +1,3 @@ -/* - * ctrlserver controls the WireGuard mesh. Contains an IpcHandler for - * handling commands fired by wgmesh command. - * Contains an RpcHandler for handling commands fired by another server. - */ package ctrlserver import ( @@ -13,6 +8,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl" ) +// NewCtrlServerParams are the params requried to create a new ctrl server type NewCtrlServerParams struct { WgClient *wgctrl.Client Conf *conf.WgMeshConfiguration @@ -21,32 +17,27 @@ type NewCtrlServerParams struct { SyncProvider rpc.SyncServiceServer } -/* - * NewCtrlServer creates a new instance of the ctrlserver. - * It is associated with a WireGuard client and an interface. - * wgClient: Represents the WireGuard control client. - * ifName: WireGuard interface name - */ +// Create a new instance of the MeshCtrlServer or error if the +// operation failed func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { ctrlServer := new(MeshCtrlServer) ctrlServer.Client = params.WgClient ctrlServer.MeshManager = mesh.NewMeshManager(*params.WgClient, *params.Conf) ctrlServer.Conf = params.Conf - connManagerParams := conn.NewJwtConnectionManagerParams{ + connManagerParams := conn.NewConnectionManageParams{ CertificatePath: params.Conf.CertificatePath, PrivateKey: params.Conf.PrivateKeyPath, SkipCertVerification: params.Conf.SkipCertVerification, } - connMgr, err := conn.NewJwtConnectionManager(&connManagerParams) + connMgr, err := conn.NewConnectionManager(&connManagerParams) if err != nil { return nil, err } ctrlServer.ConnectionManager = connMgr - connServerParams := conn.NewConnectionServerParams{ Conf: params.Conf, AuthProvider: params.AuthProvider, @@ -63,3 +54,16 @@ func NewCtrlServer(params *NewCtrlServerParams) (*MeshCtrlServer, error) { ctrlServer.ConnectionServer = connServer return ctrlServer, nil } + +// Close closes the ctrl server tearing down any connections that exist +func (s *MeshCtrlServer) Close() error { + if err := s.ConnectionManager.Close(); err != nil { + return err + } + + if err := s.ConnectionServer.Close(); err != nil { + return err + } + + return nil +} diff --git a/pkg/log/log.go b/pkg/log/log.go index 56310ae..2a07ddb 100644 --- a/pkg/log/log.go +++ b/pkg/log/log.go @@ -1,22 +1,51 @@ +// Provides a generic interface for logging package logging -/* - * This package creates the info, warning and error loggers. - */ - import ( - "log" "os" + + "github.com/sirupsen/logrus" ) var ( - InfoLog *log.Logger - WarningLog *log.Logger - ErrorLog *log.Logger + Log Logger ) -func init() { - InfoLog = log.New(os.Stdout, "INFO: ", log.Ldate|log.Ltime|log.Lshortfile) - WarningLog = log.New(os.Stdout, "WARNING: ", log.Ldate|log.Ltime|log.Lshortfile) - ErrorLog = log.New(os.Stderr, "ERROR: ", log.Ldate|log.Ltime|log.Lshortfile) +type Logger interface { + WriteInfof(msg string, args ...interface{}) + WriteErrorf(msg string, args ...interface{}) + WriteWarnf(msg string, args ...interface{}) +} + +type LogrusLogger struct { + logger *logrus.Logger +} + +func (l *LogrusLogger) WriteInfof(msg string, args ...interface{}) { + l.logger.Infof(msg, args...) +} + +func (l *LogrusLogger) WriteErrorf(msg string, args ...interface{}) { + l.logger.Errorf(msg, args...) +} + +func (l *LogrusLogger) WriteWarnf(msg string, args ...interface{}) { + l.logger.Warnf(msg, args...) +} + +func NewLogrusLogger() *LogrusLogger { + logger := logrus.New() + logger.SetFormatter(&logrus.TextFormatter{FullTimestamp: true}) + logger.SetOutput(os.Stdout) + logger.SetLevel(logrus.InfoLevel) + + return &LogrusLogger{logger: logger} +} + +func init() { + SetLogger(NewLogrusLogger()) +} + +func SetLogger(l Logger) { + Log = l } diff --git a/pkg/middleware/auth.go b/pkg/middleware/auth.go index 00d7274..388b507 100644 --- a/pkg/middleware/auth.go +++ b/pkg/middleware/auth.go @@ -4,16 +4,17 @@ import ( "context" "errors" - "github.com/tim-beatham/wgmesh/pkg/auth" logging "github.com/tim-beatham/wgmesh/pkg/log" "github.com/tim-beatham/wgmesh/pkg/rpc" ) +// AuthRpcProvider implements the AuthRpcProvider service type AuthRpcProvider struct { rpc.UnimplementedAuthenticationServer - Manager *auth.JwtManager } +// JoinMesh handles a JoinMeshRequest. Succeeds by stating the node managed to join the mesh +// or returns an error if it failed func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequest) (*rpc.JoinAuthMeshReply, error) { meshId := in.MeshId @@ -21,12 +22,8 @@ func (a *AuthRpcProvider) JoinMesh(ctx context.Context, in *rpc.JoinAuthMeshRequ return nil, errors.New("Must specify the meshId") } - logging.InfoLog.Println("MeshID: " + in.MeshId) - token, err := a.Manager.CreateClaims(in.MeshId, in.Alias) + logging.Log.WriteInfof("MeshID: " + in.MeshId) - if err != nil { - return nil, err - } - - return &rpc.JoinAuthMeshReply{Success: true, Token: token}, nil + var token string = "" + return &rpc.JoinAuthMeshReply{Success: true, Token: &token}, nil } diff --git a/pkg/robin/robin_requester.go b/pkg/robin/robin_requester.go index 2296468..8cc070c 100644 --- a/pkg/robin/robin_requester.go +++ b/pkg/robin/robin_requester.go @@ -75,71 +75,9 @@ func (n *RobinIpc) ListMeshes(_ string, reply *ipc.ListMeshReply) error { return nil } -func (n *RobinIpc) Authenticate(meshId, endpoint string) error { - peerConnection, err := n.Server.ConnectionManager.AddConnection(endpoint) - - if err != nil { - return err - } - - err = peerConnection.Authenticate(meshId) - - if err != nil { - return err - } - - return err -} - -func (n *RobinIpc) authenticatePeers(meshId string) error { - theMesh := n.Server.MeshManager.GetMesh(meshId) - - if theMesh == nil { - return errors.New("the mesh does not exist") - } - - snapshot, _ := theMesh.GetCrdt() - publicKey, err := n.Server.MeshManager.GetPublicKey(meshId) - - if err != nil { - return err - } - - for nodeKey, node := range snapshot.Nodes { - logging.InfoLog.Println(nodeKey) - if nodeKey == publicKey.String() { - continue - } - - err := n.Authenticate(meshId, node.HostEndpoint) - - if err != nil { - return err - } - } - - return nil -} - func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { - err := n.Authenticate(args.MeshId, args.IpAdress) - - if err != nil { - return err - } - peerConnection, err := n.Server.ConnectionManager.GetConnection(args.IpAdress) - if err != nil { - return err - } - - err = peerConnection.Connect() - - if err != nil { - return err - } - client, err := peerConnection.GetClient() if err != nil { @@ -148,13 +86,11 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { c := rpc.NewMeshCtrlServerClient(client) - authContext, err := peerConnection.CreateAuthContext(args.MeshId) - if err != nil { return err } - ctx, cancel := context.WithTimeout(authContext, time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() meshReply, err := c.GetMesh(ctx, &rpc.GetMeshRequest{MeshId: args.MeshId}) @@ -181,7 +117,7 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { return err } - logging.InfoLog.Println("WgIP: " + ipAddr.String()) + logging.Log.WriteInfof("WgIP: " + ipAddr.String()) outBoundIP := lib.GetOutboundIP() @@ -206,10 +142,6 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { return err } - if joinReply.GetSuccess() { - err = n.authenticatePeers(args.MeshId) - } - if err != nil { return err } diff --git a/pkg/robin/robin_responder.go b/pkg/robin/robin_responder.go index 8d5d4df..2f77838 100644 --- a/pkg/robin/robin_responder.go +++ b/pkg/robin/robin_responder.go @@ -56,7 +56,7 @@ func (m *RobinRpc) GetMesh(ctx context.Context, request *rpc.GetMeshRequest) (*r func (m *RobinRpc) JoinMesh(ctx context.Context, request *rpc.JoinMeshRequest) (*rpc.JoinMeshReply, error) { mesh := m.Server.MeshManager.GetMesh(request.MeshId) - logging.InfoLog.Println("[JOINING MESH]: " + request.MeshId) + logging.Log.WriteInfof("[JOINING MESH]: " + request.MeshId) if mesh == nil { return nil, errors.New("mesh does not exist") diff --git a/pkg/sync/syncererror.go b/pkg/sync/syncererror.go index 3b4cc6e..4cbd655 100644 --- a/pkg/sync/syncererror.go +++ b/pkg/sync/syncererror.go @@ -34,7 +34,7 @@ func (s *SyncErrorHandlerImpl) incrementFailedCount(meshId string, endpoint stri func (s *SyncErrorHandlerImpl) Handle(meshId string, endpoint string, err error) bool { errStatus, _ := status.FromError(err) - logging.WarningLog.Printf("Handled gRPC error: %s", errStatus.Message()) + logging.Log.WriteInfof("Handled gRPC error: %s", errStatus.Message()) switch errStatus.Code() { case codes.Unavailable, codes.Unknown, codes.DeadlineExceeded, codes.Internal, codes.NotFound: diff --git a/pkg/sync/syncrequester.go b/pkg/sync/syncrequester.go index edc1ea0..b9c05be 100644 --- a/pkg/sync/syncrequester.go +++ b/pkg/sync/syncrequester.go @@ -23,22 +23,6 @@ type SyncRequesterImpl struct { errorHdlr SyncErrorHandler } -func (s *SyncRequesterImpl) Authenticate(meshId, endpoint string) error { - peerConnection, err := s.server.ConnectionManager.AddConnection(endpoint) - - if err != nil { - return err - } - - err = peerConnection.Authenticate(meshId) - - if err != nil { - return err - } - - return err -} - // GetMesh: Retrieves the local state of the mesh at the endpoint func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error { peerConnection, err := s.server.ConnectionManager.GetConnection(endPoint) @@ -47,12 +31,6 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error { return err } - err = peerConnection.Connect() - - if err != nil { - return err - } - client, err := peerConnection.GetClient() if err != nil { @@ -60,13 +38,8 @@ func (s *SyncRequesterImpl) GetMesh(meshId string, endPoint string) error { } c := rpc.NewSyncServiceClient(client) - authContext, err := peerConnection.CreateAuthContext(meshId) - if err != nil { - return err - } - - ctx, cancel := context.WithTimeout(authContext, time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() reply, err := c.GetConf(ctx, &rpc.GetConfRequest{MeshId: meshId}) @@ -91,34 +64,18 @@ func (s *SyncRequesterImpl) handleErr(meshId, endpoint string, err error) error // SyncMesh: Proactively send a sync request to the other mesh func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error { - if !s.server.ConnectionManager.HasConnection(endpoint) { - s.Authenticate(meshId, endpoint) - } - peerConnection, err := s.server.ConnectionManager.GetConnection(endpoint) if err != nil { return err } - err = peerConnection.Connect() - - if err != nil { - return s.handleErr(meshId, endpoint, err) - } - client, err := peerConnection.GetClient() if err != nil { return err } - authContext, err := peerConnection.CreateAuthContext(meshId) - - if err != nil { - return err - } - mesh := s.server.MeshManager.GetMesh(meshId) if mesh == nil { @@ -127,7 +84,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error { c := rpc.NewSyncServiceClient(client) - ctx, cancel := context.WithTimeout(authContext, 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() err = syncMesh(mesh, ctx, c) @@ -136,7 +93,7 @@ func (s *SyncRequesterImpl) SyncMesh(meshId, endpoint string) error { return s.handleErr(meshId, endpoint, err) } - logging.InfoLog.Printf("Synced with node: %s meshId: %s\n", endpoint, meshId) + logging.Log.WriteInfof("Synced with node: %s meshId: %s\n", endpoint, meshId) mesh.DecrementFailedCount(endpoint) return nil } @@ -162,7 +119,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe in, err := stream.Recv() if err != nil && err != io.EOF { - logging.ErrorLog.Printf("Stream recv error: %s\n", err.Error()) + logging.Log.WriteInfof("Stream recv error: %s\n", err.Error()) return err } @@ -171,7 +128,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe } if err != nil { - logging.ErrorLog.Printf("Syncer recv error: %s\n", err.Error()) + logging.Log.WriteInfof("Syncer recv error: %s\n", err.Error()) return err } @@ -180,7 +137,7 @@ func syncMesh(mesh *crdt.CrdtNodeManager, ctx context.Context, client rpc.SyncSe } } - logging.InfoLog.Println("SYNC finished") + logging.Log.WriteInfof("SYNC finished") stream.CloseSend() return nil } diff --git a/pkg/sync/syncscheduler.go b/pkg/sync/syncscheduler.go index df09eda..9e81b78 100644 --- a/pkg/sync/syncscheduler.go +++ b/pkg/sync/syncscheduler.go @@ -34,7 +34,7 @@ func (s *SyncSchedulerImpl) Run() error { err := s.syncer.SyncMeshes() if err != nil { - logging.ErrorLog.Println(err.Error()) + logging.Log.WriteErrorf(err.Error()) } break case <-quit: diff --git a/pkg/sync/syncservice.go b/pkg/sync/syncservice.go index e4e72bf..9a4c0cb 100644 --- a/pkg/sync/syncservice.go +++ b/pkg/sync/syncservice.go @@ -41,9 +41,9 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error var syncer *crdt.AutomergeSync = nil for { - logging.InfoLog.Println("Received Attempt") + logging.Log.WriteInfof("Received Attempt") in, err := stream.Recv() - logging.InfoLog.Println("Received Worked") + logging.Log.WriteInfof("Received Worked") if err == io.EOF { return nil @@ -84,7 +84,6 @@ func (s *SyncServiceImpl) SyncMesh(stream rpc.SyncService_SyncMeshServer) error } if !moreMessages || err == io.EOF { - logging.InfoLog.Println("SYNC Completed") return nil } } diff --git a/pkg/wg/wg.go b/pkg/wg/wg.go index 51cc002..4d326c6 100644 --- a/pkg/wg/wg.go +++ b/pkg/wg/wg.go @@ -62,7 +62,7 @@ func EnableInterface(ifName string, ip string) error { cmd := exec.Command("/usr/bin/ip", "link", "set", "up", "dev", ifName) if err := cmd.Run(); err != nil { - logging.ErrorLog.Println(err.Error()) + logging.Log.WriteErrorf(err.Error()) return err }