mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-02 00:41:11 +01:00
The API authentication with PATs was not considering different userIDClaim that some of the IdPs are using. In this PR we read the userIDClaim from the config file instead of using the fixed default and only keep it as a fallback if none in defined.
125 lines
2.9 KiB
Go
125 lines
2.9 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt"
|
|
|
|
"github.com/netbirdio/netbird/management/server"
|
|
)
|
|
|
|
const (
|
|
audience = "audience"
|
|
userIDClaim = "userIDClaim"
|
|
accountID = "accountID"
|
|
domain = "domain"
|
|
userID = "userID"
|
|
tokenID = "tokenID"
|
|
PAT = "PAT"
|
|
JWT = "JWT"
|
|
wrongToken = "wrongToken"
|
|
)
|
|
|
|
var testAccount = &server.Account{
|
|
Id: accountID,
|
|
Domain: domain,
|
|
Users: map[string]*server.User{
|
|
userID: {
|
|
Id: userID,
|
|
PATs: map[string]*server.PersonalAccessToken{
|
|
tokenID: {
|
|
ID: tokenID,
|
|
Name: "My first token",
|
|
HashedToken: "someHash",
|
|
ExpirationDate: time.Now().UTC().AddDate(0, 0, 7),
|
|
CreatedBy: userID,
|
|
CreatedAt: time.Now().UTC(),
|
|
LastUsed: time.Now().UTC(),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
|
|
if token == PAT {
|
|
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
|
|
}
|
|
return nil, nil, nil, fmt.Errorf("PAT invalid")
|
|
}
|
|
|
|
func mockValidateAndParseToken(token string) (*jwt.Token, error) {
|
|
if token == JWT {
|
|
return &jwt.Token{}, nil
|
|
}
|
|
return nil, fmt.Errorf("JWT invalid")
|
|
}
|
|
|
|
func mockMarkPATUsed(token string) error {
|
|
if token == tokenID {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("Should never get reached")
|
|
}
|
|
|
|
func TestAuthMiddleware_Handler(t *testing.T) {
|
|
tt := []struct {
|
|
name string
|
|
authHeader string
|
|
expectedStatusCode int
|
|
}{
|
|
{
|
|
name: "Valid PAT Token",
|
|
authHeader: "Token " + PAT,
|
|
expectedStatusCode: 200,
|
|
},
|
|
{
|
|
name: "Invalid PAT Token",
|
|
authHeader: "Token " + wrongToken,
|
|
expectedStatusCode: 401,
|
|
},
|
|
{
|
|
name: "Valid JWT Token",
|
|
authHeader: "Bearer " + JWT,
|
|
expectedStatusCode: 200,
|
|
},
|
|
{
|
|
name: "Invalid JWT Token",
|
|
authHeader: "Bearer " + wrongToken,
|
|
expectedStatusCode: 401,
|
|
},
|
|
{
|
|
name: "Basic Auth",
|
|
authHeader: "Basic " + PAT,
|
|
expectedStatusCode: 401,
|
|
},
|
|
}
|
|
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// do nothing
|
|
})
|
|
|
|
authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience, userIDClaim)
|
|
|
|
handlerToTest := authMiddleware.Handler(nextHandler)
|
|
|
|
for _, tc := range tt {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "http://testing", nil)
|
|
req.Header.Set("Authorization", tc.authHeader)
|
|
rec := httptest.NewRecorder()
|
|
|
|
handlerToTest.ServeHTTP(rec, req)
|
|
|
|
if rec.Result().StatusCode != tc.expectedStatusCode {
|
|
t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, rec.Result().StatusCode)
|
|
}
|
|
})
|
|
}
|
|
|
|
}
|