JWT Authentication endpoint

This commit is contained in:
Tim Beatham 2023-10-01 20:14:09 +01:00
parent 94afd68460
commit 52e5e3d33c
7 changed files with 175 additions and 30 deletions

View File

@ -3,9 +3,12 @@ package auth
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
@ -22,7 +25,7 @@ type JwtMesh struct {
// JwtManager manages jwt tokens indicating a session
// between this host and another within a specific mesh
type JwtManager struct {
secretKey string
secretKey []byte
tokenDuration time.Duration
// meshes contains all the meshes that we have sessions with
meshes map[string]*JwtMesh
@ -37,7 +40,7 @@ type JwtNode struct {
func NewJwtManager(secretKey string, tokenDuration time.Duration) *JwtManager {
meshes := make(map[string]*JwtMesh)
return &JwtManager{secretKey, tokenDuration, meshes}
return &JwtManager{[]byte(secretKey), tokenDuration, meshes}
}
func (m *JwtManager) CreateClaims(meshId string, alias string) (*string, error) {
@ -52,13 +55,17 @@ func (m *JwtManager) CreateClaims(meshId string, alias string) (*string, error)
mesh, contains := m.meshes[meshId]
if !contains {
return nil, errors.New("The specified mesh does not exist")
mesh = new(JwtMesh)
mesh.meshId = meshId
mesh.nodes = make(map[string]interface{})
mesh.nodes[meshId] = mesh
}
token := jwt.NewWithClaims(jwt.SigningMethodES256, node)
signedString, err := token.SignedString([]byte(m.secretKey))
token := jwt.NewWithClaims(jwt.SigningMethodHS256, node)
signedString, err := token.SignedString(m.secretKey)
if err != nil {
fmt.Println(err.Error())
return nil, err
}
@ -74,7 +81,7 @@ func (m *JwtManager) CreateClaims(meshId string, alias string) (*string, error)
func (m *JwtManager) Verify(accessToken string) (*JwtNode, bool) {
token, err := jwt.ParseWithClaims(accessToken, &JwtNode{}, func(t *jwt.Token) (interface{}, error) {
return []byte(m.secretKey), nil
return m.secretKey, nil
})
if err != nil {
@ -96,6 +103,11 @@ func (m *JwtManager) GetAuthInterceptor() grpc.UnaryServerInterceptor {
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
if strings.Contains(info.FullMethod, "Auth") {
return handler(ctx, req)
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
@ -104,6 +116,10 @@ func (m *JwtManager) GetAuthInterceptor() grpc.UnaryServerInterceptor {
values := md["authorization"]
for _, w := range values {
logging.InfoLog.Printf(w)
}
if len(values) == 0 {
return nil, status.Errorf(codes.Unauthenticated, "authorization token is not provided")
}

51
pkg/auth/token.go Normal file
View File

@ -0,0 +1,51 @@
package auth
import (
"errors"
logging "github.com/tim-beatham/wgmesh/pkg/log"
)
type TokenMesh struct {
Tokens map[string]string
}
type TokenManager struct {
Meshes map[string]*TokenMesh
}
func (m *TokenManager) AddToken(meshId, endpoint, token string) error {
mesh, ok := m.Meshes[endpoint]
if !ok {
mesh = new(TokenMesh)
mesh.Tokens = make(map[string]string)
m.Meshes[endpoint] = mesh
}
mesh.Tokens[meshId] = token
return nil
}
func (m *TokenManager) GetToken(meshId, endpoint string) (string, error) {
mesh, ok := m.Meshes[endpoint]
if !ok {
logging.ErrorLog.Printf("Endpoint doesnot exist: %s\n", endpoint)
return "", errors.New("Endpoint does not exist in the token manager")
}
token, ok := mesh.Tokens[meshId]
if !ok {
return "", errors.New("MeshId does not exist")
}
return token, nil
}
func NewTokenManager() *TokenManager {
var manager *TokenManager = new(TokenManager)
manager.Meshes = make(map[string]*TokenMesh)
return manager
}

View File

@ -6,6 +6,7 @@
package ctrlserver
import (
"context"
"errors"
"net"
"time"
@ -16,6 +17,7 @@ import (
"github.com/tim-beatham/wgmesh/pkg/wg"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/metadata"
)
/*
@ -31,6 +33,7 @@ func NewCtrlServer(wgClient *wgctrl.Client, conn *conn.WgCtrlConnection, ifName
ctrlServer.Conn = conn
ctrlServer.IfName = ifName
ctrlServer.JwtManager = auth.NewJwtManager("bob123", 24*time.Hour)
ctrlServer.TokenManager = auth.NewTokenManager()
return ctrlServer
}
@ -192,3 +195,13 @@ func (s *MeshCtrlServer) EnableInterface(meshId string) error {
return wg.EnableInterface(s.IfName, node.WgHost)
}
func (s *MeshCtrlServer) AddToken(ctx context.Context, endpoint, meshId string) (context.Context, error) {
token, err := s.TokenManager.GetToken(meshId, endpoint)
if err != nil {
return nil, err
}
return metadata.AppendToOutgoingContext(ctx, "authorization", token), nil
}

View File

@ -32,4 +32,5 @@ type MeshCtrlServer struct {
IfName string
Conn *conn.WgCtrlConnection
JwtManager *auth.JwtManager
TokenManager *auth.TokenManager
}

View File

@ -64,7 +64,12 @@ func updateMesh(n *RobinIpc, meshId string, endPoint string) error {
defer conn.Close()
c := rpc.NewMeshCtrlServerClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
ctx, err := n.Server.AddToken(context.Background(), endPoint, meshId)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
getMeshReq := rpc.GetMeshRequest{
@ -103,6 +108,18 @@ func updateMesh(n *RobinIpc, meshId string, endPoint string) error {
}
func updatePeer(n *RobinIpc, node ctrlserver.MeshNode, wgHost string, meshId string) error {
token, err := n.Authenticate(meshId, node.HostEndpoint)
if err != nil {
return err
}
err = n.Server.TokenManager.AddToken(meshId, node.HostEndpoint, token)
if err != nil {
return err
}
conn, err := n.Server.Conn.Connect(node.HostEndpoint)
if err != nil {
@ -113,7 +130,12 @@ func updatePeer(n *RobinIpc, node ctrlserver.MeshNode, wgHost string, meshId str
c := rpc.NewMeshCtrlServerClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
ctx, err := n.Server.AddToken(context.Background(), node.HostEndpoint, meshId)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
dev := n.Server.GetDevice()
@ -156,7 +178,45 @@ func updatePeers(n *RobinIpc, meshId string, wgHost string, nodesToExclude []str
return nil
}
func (n *RobinIpc) Authenticate(meshId, endpoint string) (string, error) {
conn, err := n.Server.Conn.Connect(endpoint)
if err != nil {
return "", err
}
defer conn.Close()
c := rpc.NewAuthenticationClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
authRequest := rpc.JoinAuthMeshRequest{
MeshId: meshId,
Alias: lib.GetOutboundIP().String(),
}
reply, err := c.JoinMesh(ctx, &authRequest)
if err != nil {
return "", err
}
logging.InfoLog.Printf("Token: %s\n", *reply.Token)
return *reply.Token, err
}
func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
token, err := n.Authenticate(args.MeshId, args.IpAdress+":8080")
if err != nil {
return err
}
n.Server.TokenManager.AddToken(args.MeshId, args.IpAdress+":8080", token)
conn, err := n.Server.Conn.Connect(args.IpAdress + ":8080")
if err != nil {
@ -167,7 +227,12 @@ func (n *RobinIpc) JoinMesh(args ipc.JoinMeshArgs, reply *string) error {
c := rpc.NewMeshCtrlServerClient(conn)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
ctx, err := n.Server.AddToken(context.Background(), args.IpAdress+":8080", args.MeshId)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
dev := n.Server.GetDevice()

View File

@ -25,8 +25,8 @@ type JoinAuthMeshRequest struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
SharedSecret string `protobuf:"bytes,2,opt,name=sharedSecret,proto3" json:"sharedSecret,omitempty"`
MeshId string `protobuf:"bytes,1,opt,name=meshId,proto3" json:"meshId,omitempty"`
Alias string `protobuf:"bytes,2,opt,name=alias,proto3" json:"alias,omitempty"`
}
func (x *JoinAuthMeshRequest) Reset() {
@ -68,9 +68,9 @@ func (x *JoinAuthMeshRequest) GetMeshId() string {
return ""
}
func (x *JoinAuthMeshRequest) GetSharedSecret() string {
func (x *JoinAuthMeshRequest) GetAlias() string {
if x != nil {
return x.SharedSecret
return x.Alias
}
return ""
}
@ -136,24 +136,23 @@ var file_pkg_grpc_ctrlserver_authentication_proto_rawDesc = []byte{
0x0a, 0x28, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x63, 0x74, 0x72, 0x6c, 0x73,
0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61,
0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x08, 0x72, 0x70, 0x63, 0x74,
0x79, 0x70, 0x65, 0x73, 0x22, 0x51, 0x0a, 0x13, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68,
0x79, 0x70, 0x65, 0x73, 0x22, 0x43, 0x0a, 0x13, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68,
0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6d,
0x65, 0x73, 0x68, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6d, 0x65, 0x73,
0x68, 0x49, 0x64, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x68, 0x61, 0x72, 0x65, 0x64, 0x53, 0x65, 0x63,
0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x68, 0x61, 0x72, 0x65,
0x64, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x22, 0x52, 0x0a, 0x11, 0x4a, 0x6f, 0x69, 0x6e, 0x41,
0x75, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07,
0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73,
0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x19, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x18,
0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x88, 0x01,
0x01, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x32, 0x5a, 0x0a, 0x0e, 0x41,
0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x48, 0x0a,
0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x1d, 0x2e, 0x72, 0x70, 0x63, 0x74,
0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x73,
0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x72, 0x70, 0x63, 0x74, 0x79,
0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x68,
0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67, 0x2f, 0x72,
0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x68, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x18, 0x02, 0x20, 0x01,
0x28, 0x09, 0x52, 0x05, 0x61, 0x6c, 0x69, 0x61, 0x73, 0x22, 0x52, 0x0a, 0x11, 0x4a, 0x6f, 0x69,
0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18,
0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52,
0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x12, 0x19, 0x0a, 0x05, 0x74, 0x6f, 0x6b, 0x65,
0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e,
0x88, 0x01, 0x01, 0x42, 0x08, 0x0a, 0x06, 0x5f, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x32, 0x5a, 0x0a,
0x0e, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12,
0x48, 0x0a, 0x08, 0x4a, 0x6f, 0x69, 0x6e, 0x4d, 0x65, 0x73, 0x68, 0x12, 0x1d, 0x2e, 0x72, 0x70,
0x63, 0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d,
0x65, 0x73, 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x72, 0x70, 0x63,
0x74, 0x79, 0x70, 0x65, 0x73, 0x2e, 0x4a, 0x6f, 0x69, 0x6e, 0x41, 0x75, 0x74, 0x68, 0x4d, 0x65,
0x73, 0x68, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x42, 0x09, 0x5a, 0x07, 0x70, 0x6b, 0x67,
0x2f, 0x72, 0x70, 0x63, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (

View File

@ -35,7 +35,7 @@ func main() {
devices, err := client.Devices()
if err != nil {
return
return
}
fmt.Printf("Number of devices: %d\n", len(devices))