diff options
| author | mo khan <mo@mokhan.ca> | 2025-04-28 17:04:49 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-04-28 17:04:49 -0600 |
| commit | 519400fa417fb6becb14654011ad15b9f5e5fa7c (patch) | |
| tree | ed3e6b97234260b4dfed160cba83f0a0688817a1 /app/controllers/sessions | |
| parent | 059a87a80227426f854256139bbbc7309bdb6fa0 (diff) | |
feat: validate the csrf token
Diffstat (limited to 'app/controllers/sessions')
| -rw-r--r-- | app/controllers/sessions/controller.go | 2 | ||||
| -rw-r--r-- | app/controllers/sessions/controller_test.go | 29 | ||||
| -rw-r--r-- | app/controllers/sessions/service.go | 11 |
3 files changed, 36 insertions, 6 deletions
diff --git a/app/controllers/sessions/controller.go b/app/controllers/sessions/controller.go index e2f4b22..8d0e858 100644 --- a/app/controllers/sessions/controller.go +++ b/app/controllers/sessions/controller.go @@ -121,7 +121,7 @@ func (c *Controller) Create(w http.ResponseWriter, r *http.Request) { tokens, err := c.svc.Exchange(r) if err != nil { log.WithFields(r.Context(), log.Fields{"error": err}) - w.WriteHeader(http.StatusInternalServerError) + w.WriteHeader(http.StatusBadRequest) return } diff --git a/app/controllers/sessions/controller_test.go b/app/controllers/sessions/controller_test.go index 05f642b..c0c1de2 100644 --- a/app/controllers/sessions/controller_test.go +++ b/app/controllers/sessions/controller_test.go @@ -15,6 +15,7 @@ import ( xcfg "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/cfg" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/pls" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/test" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/web/cookie" ) @@ -85,25 +86,43 @@ func TestSessions(t *testing.T) { assert.Equal(t, "/dashboard", w.Header().Get("Location")) }) }) - - t.Run("with an expired authenicated session", func(t *testing.T) {}) }) t.Run("GET /session/callback", func(t *testing.T) { - t.Run("with an invalid csrf token", func(t *testing.T) {}) + t.Run("with an invalid csrf token", func(t *testing.T) { + user := mockoidc.DefaultUser() + code := srv.CreateAuthorizationCodeFor(user) + nonce := pls.GenerateRandomHex(32) + + r, w := test.RequestResponse( + "GET", + "/session/callback?code="+code+"&state=invalid", + test.WithCookie(cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute))), + ) + + mux.ServeHTTP(w, r) + + require.Equal(t, http.StatusBadRequest, w.Code) + }) + t.Run("with an invalid authorization code grant", func(t *testing.T) { r, w := test.RequestResponse("GET", "/session/callback?code=invalid") mux.ServeHTTP(w, r) - assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Equal(t, http.StatusBadRequest, w.Code) }) t.Run("with a valid authorization code grant", func(t *testing.T) { user := mockoidc.DefaultUser() code := srv.CreateAuthorizationCodeFor(user) + nonce := pls.GenerateRandomHex(32) - r, w := test.RequestResponse("GET", "/session/callback?code="+code) + r, w := test.RequestResponse( + "GET", + "/session/callback?code="+code+"&state="+nonce, + test.WithCookie(cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute))), + ) mux.ServeHTTP(w, r) diff --git a/app/controllers/sessions/service.go b/app/controllers/sessions/service.go index cbd00fe..0ee692a 100644 --- a/app/controllers/sessions/service.go +++ b/app/controllers/sessions/service.go @@ -2,6 +2,7 @@ package sessions import ( "context" + "errors" "net/http" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc" @@ -31,6 +32,16 @@ func (svc *Service) GenerateRedirectURL() (string, string) { } func (svc *Service) Exchange(r *http.Request) (*oidc.Tokens, error) { + cookies := r.CookiesNamed("oauth_state") + if len(cookies) != 1 { + return nil, errors.New("Missing CSRF token") + } + + state := r.URL.Query().Get("state") + if state != cookies[0].Value { + return nil, errors.New("Invalid CSRF token") + } + ctx := context.WithValue(r.Context(), oauth2.HTTPClient, svc.http) token, err := svc.cfg.Config.Exchange(ctx, r.URL.Query().Get("code")) if err != nil { |
