package sessions import ( "encoding/base64" "encoding/json" "net/http" "net/url" "testing" "github.com/oauth2-proxy/mockoidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" xcfg "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/cfg" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/pls" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/test" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/web/cookie" ) func TestSessions(t *testing.T) { srv := test.NewOIDCServer(t) defer srv.Close() clientID := srv.MockOIDC.Config().ClientID clientSecret := srv.MockOIDC.Config().ClientSecret cfg, err := oidc.New( t.Context(), srv.Issuer(), clientID, clientSecret, "callback_url", ) require.NoError(t, err) controller := New(cfg, http.DefaultClient) mux := http.NewServeMux() controller.MountTo(mux) t.Run("GET /session/new", func(t *testing.T) { t.Run("without an authenticated session", func(t *testing.T) { r, w := test.RequestResponse("GET", "/session/new") mux.ServeHTTP(w, r) t.Run("redirect to the OIDC Provider", func(t *testing.T) { require.Equal(t, http.StatusFound, w.Code) require.NotEmpty(t, w.Header().Get("Location")) redirectURL, err := url.Parse(w.Header().Get("Location")) require.NoError(t, err) assert.Equal(t, srv.AuthorizationEndpoint(), redirectURL.Scheme+"://"+redirectURL.Host+redirectURL.Path) assert.NotEmpty(t, redirectURL.Query().Get("state")) assert.Equal(t, srv.MockOIDC.Config().ClientID, redirectURL.Query().Get("client_id")) assert.Equal(t, "openid profile email", redirectURL.Query().Get("scope")) assert.Equal(t, cfg.Config.ClientID, redirectURL.Query().Get("audience")) assert.Equal(t, cfg.Config.RedirectURL, redirectURL.Query().Get("redirect_uri")) assert.Equal(t, "code", redirectURL.Query().Get("response_type")) }) t.Run("generates a CSRF token", func(t *testing.T) { cookieHeader := w.Header().Get("Set-Cookie") require.NotEmpty(t, cookieHeader) cookie, err := http.ParseSetCookie(w.Header().Get("Set-Cookie")) require.NoError(t, err) require.NotZero(t, cookie) }) }) t.Run("with an active authenicated session", func(t *testing.T) { t.Run("redirects to the dashboard", func(t *testing.T) { user := &domain.User{} r, w := test.RequestResponse( "GET", "/session/new", test.WithContextKeyValue(t.Context(), xcfg.CurrentUser, user), ) mux.ServeHTTP(w, r) require.Equal(t, http.StatusFound, w.Code) assert.Equal(t, "/dashboard", w.Header().Get("Location")) }) }) }) t.Run("GET /session/callback", func(t *testing.T) { t.Run("with an invalid csrf token", func(t *testing.T) { user := mockoidc.DefaultUser() code := srv.CreateAuthorizationCodeFor(user) nonce := pls.GenerateRandomHex(32) r, w := test.RequestResponse( "GET", "/session/callback?code="+code+"&state=invalid", test.WithCookie(cookie.New("oauth_state", nonce)), ) mux.ServeHTTP(w, r) require.Equal(t, http.StatusBadRequest, w.Code) }) t.Run("with an invalid authorization code grant", func(t *testing.T) { r, w := test.RequestResponse("GET", "/session/callback?code=invalid") mux.ServeHTTP(w, r) assert.Equal(t, http.StatusBadRequest, w.Code) }) t.Run("with a valid authorization code grant", func(t *testing.T) { user := mockoidc.DefaultUser() code := srv.CreateAuthorizationCodeFor(user) nonce := pls.GenerateRandomHex(32) r, w := test.RequestResponse( "GET", "/session/callback?code="+code+"&state="+nonce, test.WithCookie(cookie.New("oauth_state", nonce)), ) mux.ServeHTTP(w, r) cookie, err := http.ParseSetCookie(w.Header().Get("Set-Cookie")) require.NoError(t, err) require.NotZero(t, cookie) data, err := base64.URLEncoding.DecodeString(cookie.Value) require.NoError(t, err) tokens := map[string]interface{}{} require.NoError(t, json.Unmarshal(data, &tokens)) t.Run("stores the id token in a session cookie", func(t *testing.T) { require.NotEmpty(t, tokens["id_token"]) idToken := srv.Verify(tokens["id_token"].(string)) assert.Equal(t, user.Subject, idToken.Subject) }) t.Run("stores the access token in a session cookie", func(t *testing.T) { assert.NotEmpty(t, tokens["access_token"]) assert.Equal(t, "bearer", tokens["token_type"]) keypair, err := mockoidc.DefaultKeypair() require.NoError(t, err) token, err := keypair.VerifyJWT(tokens["access_token"].(string), nil) require.NoError(t, err) sub, err := token.Claims.GetSubject() require.NoError(t, err) assert.Equal(t, user.Subject, sub) }) t.Run("stores the refresh token in a session cookie", func(t *testing.T) { assert.NotEmpty(t, tokens["refresh_token"]) keypair, err := mockoidc.DefaultKeypair() require.NoError(t, err) token, err := keypair.VerifyJWT(tokens["refresh_token"].(string), nil) require.NoError(t, err) sub, err := token.Claims.GetSubject() require.NoError(t, err) assert.Equal(t, user.Subject, sub) }) t.Run("redirects to the homepage", func(t *testing.T) { require.Equal(t, http.StatusFound, w.Code) assert.Equal(t, "/dashboard", w.Header().Get("Location")) }) }) }) t.Run("POST /session/destroy", func(t *testing.T) { t.Run("clears the session cookie", func(t *testing.T) { cookie := cookie.New("session", "value") r, w := test.RequestResponse("POST", "/session/destroy", test.WithCookie(cookie)) mux.ServeHTTP(w, r) require.Equal(t, http.StatusFound, w.Code) assert.Equal(t, "/", w.Header().Get("Location")) assert.Equal(t, "session=; Path=/; Domain=localhost; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Max-Age=0; HttpOnly; Secure", w.Header().Get("Set-Cookie")) }) }) }