diff --git a/pkg/auth/jwt.go b/pkg/auth/jwt.go index 15ad9f4..6fddc9c 100644 --- a/pkg/auth/jwt.go +++ b/pkg/auth/jwt.go @@ -3,9 +3,12 @@ package auth import ( "context" "errors" + "fmt" + "strings" "time" "github.com/golang-jwt/jwt/v5" + logging "github.com/tim-beatham/wgmesh/pkg/log" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -22,7 +25,7 @@ type JwtMesh struct { // JwtManager manages jwt tokens indicating a session // between this host and another within a specific mesh type JwtManager struct { - secretKey string + secretKey []byte tokenDuration time.Duration // meshes contains all the meshes that we have sessions with meshes map[string]*JwtMesh @@ -37,7 +40,7 @@ type JwtNode struct { func NewJwtManager(secretKey string, tokenDuration time.Duration) *JwtManager { meshes := make(map[string]*JwtMesh) - return &JwtManager{secretKey, tokenDuration, meshes} + return &JwtManager{[]byte(secretKey), tokenDuration, meshes} } func (m *JwtManager) CreateClaims(meshId string, alias string) (*string, error) { @@ -52,13 +55,17 @@ func (m *JwtManager) CreateClaims(meshId string, alias string) (*string, error) mesh, contains := m.meshes[meshId] if !contains { - return nil, errors.New("The specified mesh does not exist") + mesh = new(JwtMesh) + mesh.meshId = meshId + mesh.nodes = make(map[string]interface{}) + mesh.nodes[meshId] = mesh } - token := jwt.NewWithClaims(jwt.SigningMethodES256, node) - signedString, err := token.SignedString([]byte(m.secretKey)) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, node) + signedString, err := token.SignedString(m.secretKey) if err != nil { + fmt.Println(err.Error()) return nil, err } @@ -74,7 +81,7 @@ func (m *JwtManager) CreateClaims(meshId string, alias string) (*string, error) func (m *JwtManager) Verify(accessToken string) (*JwtNode, bool) { token, err := jwt.ParseWithClaims(accessToken, &JwtNode{}, func(t *jwt.Token) (interface{}, error) { - return []byte(m.secretKey), nil + return m.secretKey, nil }) if err != nil { @@ -96,6 +103,11 @@ func (m *JwtManager) GetAuthInterceptor() grpc.UnaryServerInterceptor { info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (interface{}, error) { + + if strings.Contains(info.FullMethod, "Auth") { + return handler(ctx, req) + } + md, ok := metadata.FromIncomingContext(ctx) if !ok { @@ -104,6 +116,10 @@ func (m *JwtManager) GetAuthInterceptor() grpc.UnaryServerInterceptor { values := md["authorization"] + for _, w := range values { + logging.InfoLog.Printf(w) + } + if len(values) == 0 { return nil, status.Errorf(codes.Unauthenticated, "authorization token is not provided") } diff --git a/pkg/auth/token.go b/pkg/auth/token.go new file mode 100644 index 0000000..7ddfe43 --- /dev/null +++ b/pkg/auth/token.go @@ -0,0 +1,51 @@ +package auth + +import ( + "errors" + + logging "github.com/tim-beatham/wgmesh/pkg/log" +) + +type TokenMesh struct { + Tokens map[string]string +} + +type TokenManager struct { + Meshes map[string]*TokenMesh +} + +func (m *TokenManager) AddToken(meshId, endpoint, token string) error { + mesh, ok := m.Meshes[endpoint] + + if !ok { + mesh = new(TokenMesh) + mesh.Tokens = make(map[string]string) + m.Meshes[endpoint] = mesh + } + + mesh.Tokens[meshId] = token + return nil +} + +func (m *TokenManager) GetToken(meshId, endpoint string) (string, error) { + mesh, ok := m.Meshes[endpoint] + + if !ok { + logging.ErrorLog.Printf("Endpoint doesnot exist: %s\n", endpoint) + return "", errors.New("Endpoint does not exist in the token manager") + } + + token, ok := mesh.Tokens[meshId] + + if !ok { + return "", errors.New("MeshId does not exist") + } + + return token, nil +} + +func NewTokenManager() *TokenManager { + var manager *TokenManager = new(TokenManager) + manager.Meshes = make(map[string]*TokenMesh) + return manager +} diff --git a/pkg/ctrlserver/ctrlserver.go b/pkg/ctrlserver/ctrlserver.go index a9fa360..9ab4aa9 100644 --- a/pkg/ctrlserver/ctrlserver.go +++ b/pkg/ctrlserver/ctrlserver.go @@ -6,6 +6,7 @@ package ctrlserver import ( + "context" "errors" "net" "time" @@ -16,6 +17,7 @@ import ( "github.com/tim-beatham/wgmesh/pkg/wg" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "google.golang.org/grpc/metadata" ) /* @@ -31,6 +33,7 @@ func NewCtrlServer(wgClient *wgctrl.Client, conn *conn.WgCtrlConnection, ifName ctrlServer.Conn = conn ctrlServer.IfName = ifName ctrlServer.JwtManager = auth.NewJwtManager("bob123", 24*time.Hour) + ctrlServer.TokenManager = auth.NewTokenManager() return ctrlServer } @@ -192,3 +195,13 @@ func (s *MeshCtrlServer) EnableInterface(meshId string) error { return wg.EnableInterface(s.IfName, node.WgHost) } + +func (s *MeshCtrlServer) AddToken(ctx context.Context, endpoint, meshId string) (context.Context, error) { + token, err := s.TokenManager.GetToken(meshId, endpoint) + + if err != nil { + return nil, err + } + + return metadata.AppendToOutgoingContext(ctx, "authorization", token), nil +} diff --git a/pkg/ctrlserver/ctrltypes.go b/pkg/ctrlserver/ctrltypes.go index ec63c28..a2a8a85 100644 --- a/pkg/ctrlserver/ctrltypes.go +++ b/pkg/ctrlserver/ctrltypes.go @@ -32,4 +32,5 @@ type MeshCtrlServer struct { IfName string Conn *conn.WgCtrlConnection JwtManager *auth.JwtManager + TokenManager *auth.TokenManager } diff --git a/pkg/robin/robin_requester.go b/pkg/robin/robin_requester.go index d24c086..6a188fa 100644 --- a/pkg/robin/robin_requester.go +++ b/pkg/robin/robin_requester.go @@ -64,7 +64,12 @@ func updateMesh(n *RobinIpc, meshId string, endPoint string) error { defer conn.Close() c := rpc.NewMeshCtrlServerClient(conn) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, err := n.Server.AddToken(context.Background(), endPoint, meshId) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() getMeshReq := rpc.GetMeshRequest{ @@ -103,6 +108,18 @@ func updateMesh(n *RobinIpc, meshId string, endPoint string) error { } func updatePeer(n *RobinIpc, node ctrlserver.MeshNode, wgHost string, meshId string) error { + token, err := n.Authenticate(meshId, node.HostEndpoint) + + if err != nil { + return err + } + + err = n.Server.TokenManager.AddToken(meshId, node.HostEndpoint, token) + + if err != nil { + return err + } + conn, err := n.Server.Conn.Connect(node.HostEndpoint) if err != nil { @@ -113,7 +130,12 @@ func updatePeer(n *RobinIpc, node ctrlserver.MeshNode, wgHost string, meshId str c := rpc.NewMeshCtrlServerClient(conn) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, err := n.Server.AddToken(context.Background(), node.HostEndpoint, meshId) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() dev := n.Server.GetDevice() @@ -156,7 +178,45 @@ func updatePeers(n *RobinIpc, meshId string, wgHost string, nodesToExclude []str return nil } +func (n *RobinIpc) Authenticate(meshId, endpoint string) (string, error) { + conn, err := n.Server.Conn.Connect(endpoint) + + if err != nil { + return "", err + } + + defer conn.Close() + + c := rpc.NewAuthenticationClient(conn) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + authRequest := rpc.JoinAuthMeshRequest{ + MeshId: meshId, + Alias: lib.GetOutboundIP().String(), + } + + reply, err := c.JoinMesh(ctx, &authRequest) + + if err != nil { + return "", err + } + + logging.InfoLog.Printf("Token: %s\n", *reply.Token) + + return *reply.Token, err +} + func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { + token, err := n.Authenticate(args.MeshId, args.IpAdress+":8080") + + if err != nil { + return err + } + + n.Server.TokenManager.AddToken(args.MeshId, args.IpAdress+":8080", token) + conn, err := n.Server.Conn.Connect(args.IpAdress + ":8080") if err != nil { @@ -167,7 +227,12 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error { c := rpc.NewMeshCtrlServerClient(conn) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) + ctx, err := n.Server.AddToken(context.Background(), args.IpAdress+":8080", args.MeshId) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() dev := n.Server.GetDevice() diff --git a/pkg/rpc/authentication.pb.go b/pkg/rpc/authentication.pb.go index 3cf5a10..f83598d 100644 --- a/pkg/rpc/authentication.pb.go +++ b/pkg/rpc/authentication.pb.go @@ -25,8 +25,8 @@ type JoinAuthMeshRequest struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"` - SharedSecret string `protobuf:"bytes,2,opt,name=sharedSecret,proto3" json:"sharedSecret,omitempty"` + MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"` + Alias string `protobuf:"bytes,2,opt,name=alias,proto3" json:"alias,omitempty"` } func (x *JoinAuthMeshRequest) Reset() { @@ -68,9 +68,9 @@ func (x *JoinAuthMeshRequest) GetMeshId() string { return "" } -func (x *JoinAuthMeshRequest) GetSharedSecret() string { +func (x *JoinAuthMeshRequest) GetAlias() string { if x != nil { - return x.SharedSecret + return x.Alias } return "" } @@ -136,24 +136,23 @@ var file_pkg_grpc_ctrlserver_authentication_proto_rawDesc = []byte{ 0x0a, 0x28, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74, - 0x79, 0x70, 0x65, 0x73, 0x22, 0x51, 0x0a, 0x13, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, + 0x79, 0x70, 0x65, 0x73, 0x22, 0x43, 0x0a, 0x13, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d, 0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73, - 0x68, 0x49, 0x64, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x68, 0x61, 0x72, 0x65, 0x64, 0x53, 0x65, 0x63, - 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x68, 0x61, 0x72, 0x65, - 0x64, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x22, 0x52, 0x0a, 0x11, 0x4a, 0x6f, 0x69, 0x6e, 0x41, - 0x75, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, - 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, - 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x19, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x88, 0x01, - 0x01, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x32, 0x5a, 0x0a, 0x0e, 0x41, - 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x48, 0x0a, - 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x1d, 0x2e, 0x72, 0x70, 0x63, 0x74, - 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x73, - 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79, - 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x68, - 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72, - 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x68, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x22, 0x52, 0x0a, 0x11, 0x4a, 0x6f, 0x69, + 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18, + 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x19, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, + 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, + 0x88, 0x01, 0x01, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x32, 0x5a, 0x0a, + 0x0e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, + 0x48, 0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x1d, 0x2e, 0x72, 0x70, + 0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, + 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x72, 0x70, 0x63, + 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65, + 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, + 0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/wgmesh.go b/wgmesh.go index 916be71..15eb2b1 100644 --- a/wgmesh.go +++ b/wgmesh.go @@ -35,7 +35,7 @@ func main() { devices, err := client.Devices() if err != nil { - return + return } fmt.Printf("Number of devices: %d\n", len(devices))