From ea841ab274630cff287a586d9799663a28c708fc Mon Sep 17 00:00:00 2001 From: mo khan Date: Wed, 30 Apr 2025 12:18:33 -0600 Subject: refactor: extract Option[T] and cleaner API for creating cookies --- app/controllers/sessions/controller.go | 11 +++++++---- app/controllers/sessions/controller_test.go | 7 +++---- app/controllers/sessions/service_test.go | 7 +++---- 3 files changed, 13 insertions(+), 12 deletions(-) (limited to 'app/controllers/sessions') diff --git a/app/controllers/sessions/controller.go b/app/controllers/sessions/controller.go index 5babe7d..ae50e16 100644 --- a/app/controllers/sessions/controller.go +++ b/app/controllers/sessions/controller.go @@ -33,10 +33,13 @@ func (c *Controller) New(w http.ResponseWriter, r *http.Request) { } url, nonce := c.svc.GenerateRedirectURL() - cookie := cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute)) // This cookie must be sent as part of a redirect that originates from the OIDC Provider - cookie.SameSite = http.SameSiteLaxMode - http.SetCookie(w, cookie) + http.SetCookie(w, cookie.New( + "oauth_state", + nonce, + cookie.WithSameSite(http.SameSiteLaxMode), + cookie.WithExpiration(time.Now().Add(10*time.Minute)), + )) http.Redirect(w, r, url, http.StatusFound) } @@ -135,7 +138,7 @@ func (c *Controller) Create(w http.ResponseWriter, r *http.Request) { return } - http.SetCookie(w, cookie.New("session", encoded, tokens.Expiry)) + http.SetCookie(w, cookie.New("session", encoded, cookie.WithExpiration(tokens.Expiry))) http.Redirect(w, r, "/dashboard", http.StatusFound) } diff --git a/app/controllers/sessions/controller_test.go b/app/controllers/sessions/controller_test.go index 9ece4f9..c16c6cd 100644 --- a/app/controllers/sessions/controller_test.go +++ b/app/controllers/sessions/controller_test.go @@ -6,7 +6,6 @@ import ( "net/http" "net/url" "testing" - "time" "github.com/oauth2-proxy/mockoidc" "github.com/stretchr/testify/assert" @@ -93,7 +92,7 @@ func TestSessions(t *testing.T) { r, w := test.RequestResponse( "GET", "/session/callback?code="+code+"&state=invalid", - test.WithCookie(cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute))), + test.WithCookie(cookie.New("oauth_state", nonce)), ) mux.ServeHTTP(w, r) @@ -117,7 +116,7 @@ func TestSessions(t *testing.T) { r, w := test.RequestResponse( "GET", "/session/callback?code="+code+"&state="+nonce, - test.WithCookie(cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute))), + test.WithCookie(cookie.New("oauth_state", nonce)), ) mux.ServeHTTP(w, r) @@ -175,7 +174,7 @@ func TestSessions(t *testing.T) { t.Run("POST /session/destroy", func(t *testing.T) { t.Run("clears the session cookie", func(t *testing.T) { - cookie := cookie.New("session", "value", time.Now().Add(5*time.Minute)) + cookie := cookie.New("session", "value") r, w := test.RequestResponse("POST", "/session/destroy", test.WithCookie(cookie)) mux.ServeHTTP(w, r) diff --git a/app/controllers/sessions/service_test.go b/app/controllers/sessions/service_test.go index f85c9be..c2de6f4 100644 --- a/app/controllers/sessions/service_test.go +++ b/app/controllers/sessions/service_test.go @@ -3,7 +3,6 @@ package sessions import ( "net/http" "testing" - "time" "github.com/oauth2-proxy/mockoidc" "github.com/stretchr/testify/assert" @@ -47,7 +46,7 @@ func TestService(t *testing.T) { r := test.Request( "GET", "/session/callback?code="+code+"&state=invalid", - test.WithCookie(cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute))), + test.WithCookie(cookie.New("oauth_state", nonce)), ) tokens, err := svc.Exchange(r) @@ -60,7 +59,7 @@ func TestService(t *testing.T) { r := test.Request( "GET", "/session/callback?code=invalid", - test.WithCookie(cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute))), + test.WithCookie(cookie.New("oauth_state", nonce)), ) tokens, err := svc.Exchange(r) @@ -77,7 +76,7 @@ func TestService(t *testing.T) { r := test.Request( "GET", "/session/callback?code="+code+"&state="+nonce, - test.WithCookie(cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute))), + test.WithCookie(cookie.New("oauth_state", nonce)), ) tokens, err := svc.Exchange(r) -- cgit v1.2.3