smegmesh/pkg/auth/jwt.go

141 lines
3.2 KiB
Go

package auth
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
logging "github.com/tim-beatham/wgmesh/pkg/log"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// JwtMesh contains all the sessions with the mesh network
type JwtMesh struct {
meshId string
// nodes contains a set of nodes with the string being the jwt token
nodes map[string]interface{}
}
// JwtManager manages jwt tokens indicating a session
// between this host and another within a specific mesh
type JwtManager struct {
secretKey []byte
tokenDuration time.Duration
// meshes contains all the meshes that we have sessions with
meshes map[string]*JwtMesh
}
// JwtNode represents a jwt node in the mesh network
type JwtNode struct {
MeshId string `json:"meshId"`
Alias string `json:"alias"`
jwt.RegisteredClaims
}
func NewJwtManager(secretKey string, tokenDuration time.Duration) *JwtManager {
meshes := make(map[string]*JwtMesh)
return &JwtManager{[]byte(secretKey), tokenDuration, meshes}
}
func (m *JwtManager) CreateClaims(meshId string, alias string) (*string, error) {
logging.InfoLog.Println("MeshID: " + meshId)
logging.InfoLog.Println("Token Duration: " + strconv.Itoa(int(m.tokenDuration)))
node := JwtNode{
MeshId: meshId,
Alias: alias,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(m.tokenDuration)),
},
}
mesh, contains := m.meshes[meshId]
if !contains {
mesh = new(JwtMesh)
mesh.meshId = meshId
mesh.nodes = make(map[string]interface{})
mesh.nodes[meshId] = mesh
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, node)
signedString, err := token.SignedString(m.secretKey)
if err != nil {
fmt.Println(err.Error())
return nil, err
}
_, exists := mesh.nodes[signedString]
if exists {
return nil, errors.New("Node already exists")
}
mesh.nodes[signedString] = struct{}{}
return &signedString, nil
}
func (m *JwtManager) Verify(accessToken string) (*JwtNode, bool) {
token, err := jwt.ParseWithClaims(accessToken, &JwtNode{}, func(t *jwt.Token) (interface{}, error) {
return m.secretKey, nil
})
if err != nil {
return nil, false
}
if !token.Valid {
return nil, token.Valid
}
claims, ok := token.Claims.(*JwtNode)
return claims, ok
}
func (m *JwtManager) GetAuthInterceptor() grpc.UnaryServerInterceptor {
return func(
ctx context.Context,
req interface{},
info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
if strings.Contains(info.FullMethod, "Auth") {
return handler(ctx, req)
}
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "metadata is not provided")
}
values := md["authorization"]
for _, w := range values {
logging.InfoLog.Printf(w)
}
if len(values) == 0 {
return nil, status.Errorf(codes.Unauthenticated, "authorization token is not provided")
}
acessToken := values[0]
_, valid := m.Verify(acessToken)
if !valid {
return nil, status.Errorf(codes.Unauthenticated, "Invalid access token: %s", acessToken)
}
return handler(ctx, req)
}
}