From 2eaf4aa8d7b8972b95772374551466445d43d49c Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Fri, 31 Mar 2023 12:44:22 +0200 Subject: [PATCH] add test for auth middleware --- .../http/middleware/auth_middleware_test.go | 123 ++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 management/server/http/middleware/auth_middleware_test.go diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go new file mode 100644 index 000000000..e77beb21f --- /dev/null +++ b/management/server/http/middleware/auth_middleware_test.go @@ -0,0 +1,123 @@ +package middleware + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/golang-jwt/jwt" + + "github.com/netbirdio/netbird/management/server" +) + +const ( + audience = "audience" + accountID = "accountID" + domain = "domain" + userID = "userID" + tokenID = "tokenID" + PAT = "PAT" + JWT = "JWT" + wrongToken = "wrongToken" +) + +var testAccount = &server.Account{ + Id: accountID, + Domain: domain, + Users: map[string]*server.User{ + userID: { + Id: userID, + PATs: map[string]*server.PersonalAccessToken{ + tokenID: { + ID: tokenID, + Name: "My first token", + HashedToken: "someHash", + ExpirationDate: time.Now().AddDate(0, 0, 7), + CreatedBy: userID, + CreatedAt: time.Now(), + LastUsed: time.Now(), + }, + }, + }, + }, +} + +func mockGetAccountFromPAT(token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { + if token == PAT { + return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil + } + return nil, nil, nil, fmt.Errorf("PAT invalid") +} + +func mockValidateAndParseToken(token string) (*jwt.Token, error) { + if token == JWT { + return &jwt.Token{}, nil + } + return nil, fmt.Errorf("JWT invalid") +} + +func mockMarkPATUsed(token string) error { + if token == tokenID { + return nil + } + return fmt.Errorf("Should never get reached") +} + +func TestAccounts_AccountsHandler(t *testing.T) { + tt := []struct { + name string + authHeader string + expectedStatusCode int + }{ + { + name: "Valid PAT Token", + authHeader: "Token " + PAT, + expectedStatusCode: 200, + }, + { + name: "Invalid PAT Token", + authHeader: "Token " + wrongToken, + expectedStatusCode: 401, + }, + { + name: "Valid JWT Token", + authHeader: "Bearer " + JWT, + expectedStatusCode: 200, + }, + { + name: "Invalid JWT Token", + authHeader: "Bearer " + wrongToken, + expectedStatusCode: 401, + }, + { + name: "Basic Auth", + authHeader: "Basic " + PAT, + expectedStatusCode: 401, + }, + } + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // do nothing + }) + + authMiddleware := NewAuthMiddleware(mockGetAccountFromPAT, mockValidateAndParseToken, mockMarkPATUsed, audience) + + handlerToTest := authMiddleware.Handler(nextHandler) + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "http://testing", nil) + req.Header.Set("Authorization", tc.authHeader) + rec := httptest.NewRecorder() + + handlerToTest.ServeHTTP(rec, req) + + if rec.Result().StatusCode != tc.expectedStatusCode { + t.Errorf("expected status code %d, got %d", tc.expectedStatusCode, rec.Result().StatusCode) + } + }) + } + +}