mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-26 10:03:47 +01:00
a2fc4ec221
Add method for rotating access token with refresh tokens. This will be useful for catching expired sessions and offboarding users. Also added functions to handle secrets. They have to be revisited, as some tests didn't run on CI because they waited for user input, such as a password.
416 lines
12 KiB
Go
416 lines
12 KiB
Go
package oauth
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/golang-jwt/jwt"
|
|
"github.com/stretchr/testify/require"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// mockHTTPClient is a scripted stand-in for the HTTP client handed to Auth0 in
// these tests. Its Do method remembers the body of the most recent request and
// answers with countResBody for the first MaxReqs calls, then with resBody.
type mockHTTPClient struct {
	// code is the status code placed on every response.
	code int
	// resBody is the response body served once the scripted MaxReqs responses
	// are used up; reqBody holds the body of the most recently read request.
	resBody, reqBody string
	// MaxReqs is how many calls receive countResBody; count tracks how many of
	// those scripted responses have been served so far.
	MaxReqs, count int
	// countResBody is the body returned for the first MaxReqs calls.
	countResBody string
	// err is returned alongside every response.
	err error
}
|
|
|
|
func (c *mockHTTPClient) Do(req *http.Request) (*http.Response, error) {
|
|
body, err := ioutil.ReadAll(req.Body)
|
|
if err == nil {
|
|
c.reqBody = string(body)
|
|
}
|
|
|
|
if c.MaxReqs > c.count {
|
|
c.count++
|
|
return &http.Response{
|
|
StatusCode: c.code,
|
|
Body: ioutil.NopCloser(strings.NewReader(c.countResBody)),
|
|
}, c.err
|
|
}
|
|
|
|
return &http.Response{
|
|
StatusCode: c.code,
|
|
Body: ioutil.NopCloser(strings.NewReader(c.resBody)),
|
|
}, c.err
|
|
}
|
|
|
|
func TestAuth0_RequestDeviceCode(t *testing.T) {
|
|
type test struct {
|
|
name string
|
|
inputResBody string
|
|
inputReqCode int
|
|
inputReqError error
|
|
testingErrFunc require.ErrorAssertionFunc
|
|
expectedErrorMSG string
|
|
testingFunc require.ComparisonAssertionFunc
|
|
expectedOut DeviceAuthInfo
|
|
expectedMSG string
|
|
expectPayload RequestDeviceCodePayload
|
|
}
|
|
|
|
testCase1 := test{
|
|
name: "Payload Is Valid",
|
|
expectPayload: RequestDeviceCodePayload{
|
|
Audience: "ok",
|
|
ClientID: "bla",
|
|
},
|
|
inputReqCode: 200,
|
|
testingErrFunc: require.Error,
|
|
testingFunc: require.EqualValues,
|
|
}
|
|
|
|
testCase2 := test{
|
|
name: "Exit On Network Error",
|
|
inputReqError: fmt.Errorf("error"),
|
|
testingErrFunc: require.Error,
|
|
expectedErrorMSG: "should return error",
|
|
testingFunc: require.EqualValues,
|
|
}
|
|
|
|
testCase3 := test{
|
|
name: "Exit On Exit Code",
|
|
inputReqCode: 400,
|
|
testingErrFunc: require.Error,
|
|
expectedErrorMSG: "should return error",
|
|
testingFunc: require.EqualValues,
|
|
}
|
|
testCase4Out := DeviceAuthInfo{ExpiresIn: 10}
|
|
testCase4 := test{
|
|
name: "Got Device Code",
|
|
inputResBody: fmt.Sprintf("{\"expires_in\":%d}", testCase4Out.ExpiresIn),
|
|
expectPayload: RequestDeviceCodePayload{
|
|
Audience: "ok",
|
|
ClientID: "bla",
|
|
},
|
|
inputReqCode: 200,
|
|
testingErrFunc: require.NoError,
|
|
testingFunc: require.EqualValues,
|
|
expectedOut: testCase4Out,
|
|
expectedMSG: "out should match",
|
|
}
|
|
|
|
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4} {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
|
|
httpClient := mockHTTPClient{
|
|
resBody: testCase.inputResBody,
|
|
code: testCase.inputReqCode,
|
|
err: testCase.inputReqError,
|
|
}
|
|
|
|
auth0 := Auth0{
|
|
Audience: testCase.expectPayload.Audience,
|
|
ClientID: testCase.expectPayload.ClientID,
|
|
Domain: "test.auth0.com",
|
|
HTTPClient: &httpClient,
|
|
}
|
|
|
|
authInfo, err := auth0.RequestDeviceCode(context.TODO())
|
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
|
|
|
payload, _ := json.Marshal(testCase.expectPayload)
|
|
|
|
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
|
|
|
|
testCase.testingFunc(t, testCase.expectedOut, authInfo, testCase.expectedMSG)
|
|
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuth0_WaitToken(t *testing.T) {
|
|
type test struct {
|
|
name string
|
|
inputResBody string
|
|
inputReqCode int
|
|
inputReqError error
|
|
inputMaxReqs int
|
|
inputCountResBody string
|
|
inputTimeout time.Duration
|
|
inputInfo DeviceAuthInfo
|
|
inputAudience string
|
|
testingErrFunc require.ErrorAssertionFunc
|
|
expectedErrorMSG string
|
|
testingFunc require.ComparisonAssertionFunc
|
|
expectedOut TokenInfo
|
|
expectedMSG string
|
|
expectPayload TokenRequestPayload
|
|
}
|
|
|
|
defaultInfo := DeviceAuthInfo{
|
|
DeviceCode: "test",
|
|
ExpiresIn: 10,
|
|
Interval: 1,
|
|
}
|
|
|
|
tokenReqPayload := TokenRequestPayload{
|
|
GrantType: auth0GrantType,
|
|
DeviceCode: defaultInfo.DeviceCode,
|
|
ClientID: "test",
|
|
}
|
|
|
|
testCase1 := test{
|
|
name: "Payload Is Valid",
|
|
inputInfo: defaultInfo,
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
inputReqCode: 200,
|
|
testingErrFunc: require.Error,
|
|
testingFunc: require.EqualValues,
|
|
expectPayload: tokenReqPayload,
|
|
}
|
|
|
|
testCase2 := test{
|
|
name: "Exit On Network Error",
|
|
inputInfo: defaultInfo,
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
expectPayload: tokenReqPayload,
|
|
inputReqError: fmt.Errorf("error"),
|
|
testingErrFunc: require.Error,
|
|
expectedErrorMSG: "should return error",
|
|
testingFunc: require.EqualValues,
|
|
}
|
|
|
|
testCase3 := test{
|
|
name: "Exit On 4XX When Not Pending",
|
|
inputInfo: defaultInfo,
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
inputReqCode: 400,
|
|
expectPayload: tokenReqPayload,
|
|
testingErrFunc: require.Error,
|
|
expectedErrorMSG: "should return error",
|
|
testingFunc: require.EqualValues,
|
|
}
|
|
|
|
testCase4 := test{
|
|
name: "Exit On Exit Code 5XX",
|
|
inputInfo: defaultInfo,
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
inputReqCode: 500,
|
|
expectPayload: tokenReqPayload,
|
|
testingErrFunc: require.Error,
|
|
expectedErrorMSG: "should return error",
|
|
testingFunc: require.EqualValues,
|
|
}
|
|
|
|
testCase5 := test{
|
|
name: "Exit On Content Timeout",
|
|
inputInfo: defaultInfo,
|
|
inputTimeout: 0 * time.Second,
|
|
testingErrFunc: require.Error,
|
|
expectedErrorMSG: "should return error",
|
|
testingFunc: require.EqualValues,
|
|
}
|
|
|
|
audience := "test"
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"aud": audience})
|
|
var hmacSampleSecret []byte
|
|
tokenString, _ := token.SignedString(hmacSampleSecret)
|
|
|
|
testCase6 := test{
|
|
name: "Exit On Invalid Audience",
|
|
inputInfo: defaultInfo,
|
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
inputReqCode: 200,
|
|
inputAudience: "super test",
|
|
testingErrFunc: require.Error,
|
|
testingFunc: require.EqualValues,
|
|
expectPayload: tokenReqPayload,
|
|
}
|
|
|
|
testCase7 := test{
|
|
name: "Received Token Info",
|
|
inputInfo: defaultInfo,
|
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
inputReqCode: 200,
|
|
inputAudience: audience,
|
|
testingErrFunc: require.NoError,
|
|
testingFunc: require.EqualValues,
|
|
expectPayload: tokenReqPayload,
|
|
expectedOut: TokenInfo{AccessToken: tokenString},
|
|
}
|
|
|
|
testCase8 := test{
|
|
name: "Received Token Info after Multiple tries",
|
|
inputInfo: defaultInfo,
|
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
|
inputTimeout: time.Duration(defaultInfo.ExpiresIn) * time.Second,
|
|
inputMaxReqs: 2,
|
|
inputCountResBody: "{\"error\":\"authorization_pending\"}",
|
|
inputReqCode: 200,
|
|
inputAudience: audience,
|
|
testingErrFunc: require.NoError,
|
|
testingFunc: require.EqualValues,
|
|
expectPayload: tokenReqPayload,
|
|
expectedOut: TokenInfo{AccessToken: tokenString},
|
|
}
|
|
|
|
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6, testCase7, testCase8} {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
|
|
httpClient := mockHTTPClient{
|
|
resBody: testCase.inputResBody,
|
|
code: testCase.inputReqCode,
|
|
err: testCase.inputReqError,
|
|
MaxReqs: testCase.inputMaxReqs,
|
|
countResBody: testCase.inputCountResBody,
|
|
}
|
|
|
|
auth0 := Auth0{
|
|
Audience: testCase.inputAudience,
|
|
ClientID: testCase.expectPayload.ClientID,
|
|
Domain: "test.auth0.com",
|
|
HTTPClient: &httpClient,
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.TODO(), testCase.inputTimeout)
|
|
defer cancel()
|
|
tokenInfo, err := auth0.WaitToken(ctx, testCase.inputInfo)
|
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
|
|
|
var payload []byte
|
|
var emptyPayload TokenRequestPayload
|
|
if testCase.expectPayload != emptyPayload {
|
|
payload, _ = json.Marshal(testCase.expectPayload)
|
|
}
|
|
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
|
|
|
|
testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG)
|
|
|
|
require.GreaterOrEqualf(t, testCase.inputMaxReqs, httpClient.count, "should run %d times", testCase.inputMaxReqs)
|
|
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuth0_RotateAccessToken(t *testing.T) {
|
|
type test struct {
|
|
name string
|
|
inputResBody string
|
|
inputReqCode int
|
|
inputReqError error
|
|
inputMaxReqs int
|
|
inputInfo DeviceAuthInfo
|
|
inputAudience string
|
|
testingErrFunc require.ErrorAssertionFunc
|
|
expectedErrorMSG string
|
|
testingFunc require.ComparisonAssertionFunc
|
|
expectedOut TokenInfo
|
|
expectedMSG string
|
|
expectPayload TokenRequestPayload
|
|
}
|
|
|
|
defaultInfo := DeviceAuthInfo{
|
|
DeviceCode: "test",
|
|
ExpiresIn: 10,
|
|
Interval: 1,
|
|
}
|
|
|
|
tokenReqPayload := TokenRequestPayload{
|
|
GrantType: auth0RefreshGrant,
|
|
ClientID: "test",
|
|
RefreshToken: "refresh_test",
|
|
}
|
|
|
|
testCase1 := test{
|
|
name: "Payload Is Valid",
|
|
inputInfo: defaultInfo,
|
|
inputReqCode: 200,
|
|
testingErrFunc: require.Error,
|
|
testingFunc: require.EqualValues,
|
|
expectPayload: tokenReqPayload,
|
|
}
|
|
|
|
testCase2 := test{
|
|
name: "Exit On Network Error",
|
|
inputInfo: defaultInfo,
|
|
expectPayload: tokenReqPayload,
|
|
inputReqError: fmt.Errorf("error"),
|
|
testingErrFunc: require.Error,
|
|
expectedErrorMSG: "should return error",
|
|
testingFunc: require.EqualValues,
|
|
}
|
|
|
|
testCase3 := test{
|
|
name: "Exit On Non 200 Status Code",
|
|
inputInfo: defaultInfo,
|
|
inputReqCode: 401,
|
|
expectPayload: tokenReqPayload,
|
|
testingErrFunc: require.Error,
|
|
expectedErrorMSG: "should return error",
|
|
testingFunc: require.EqualValues,
|
|
}
|
|
|
|
audience := "test"
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{"aud": audience})
|
|
var hmacSampleSecret []byte
|
|
tokenString, _ := token.SignedString(hmacSampleSecret)
|
|
|
|
testCase4 := test{
|
|
name: "Exit On Invalid Audience",
|
|
inputInfo: defaultInfo,
|
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
|
inputReqCode: 200,
|
|
inputAudience: "super test",
|
|
testingErrFunc: require.Error,
|
|
testingFunc: require.EqualValues,
|
|
expectPayload: tokenReqPayload,
|
|
}
|
|
|
|
testCase5 := test{
|
|
name: "Received Token Info",
|
|
inputInfo: defaultInfo,
|
|
inputResBody: fmt.Sprintf("{\"access_token\":\"%s\"}", tokenString),
|
|
inputReqCode: 200,
|
|
inputAudience: audience,
|
|
testingErrFunc: require.NoError,
|
|
testingFunc: require.EqualValues,
|
|
expectPayload: tokenReqPayload,
|
|
expectedOut: TokenInfo{AccessToken: tokenString},
|
|
}
|
|
|
|
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
|
|
t.Run(testCase.name, func(t *testing.T) {
|
|
|
|
httpClient := mockHTTPClient{
|
|
resBody: testCase.inputResBody,
|
|
code: testCase.inputReqCode,
|
|
err: testCase.inputReqError,
|
|
MaxReqs: testCase.inputMaxReqs,
|
|
}
|
|
|
|
auth0 := Auth0{
|
|
Audience: testCase.inputAudience,
|
|
ClientID: testCase.expectPayload.ClientID,
|
|
Domain: "test.auth0.com",
|
|
HTTPClient: &httpClient,
|
|
}
|
|
|
|
tokenInfo, err := auth0.RotateAccessToken(context.TODO(), testCase.expectPayload.RefreshToken)
|
|
testCase.testingErrFunc(t, err, testCase.expectedErrorMSG)
|
|
|
|
var payload []byte
|
|
var emptyPayload TokenRequestPayload
|
|
if testCase.expectPayload != emptyPayload {
|
|
payload, _ = json.Marshal(testCase.expectPayload)
|
|
}
|
|
require.EqualValues(t, string(payload), httpClient.reqBody, "payload should match")
|
|
|
|
testCase.testingFunc(t, testCase.expectedOut, tokenInfo, testCase.expectedMSG)
|
|
|
|
})
|
|
}
|
|
}
|