summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--app/controllers/sessions/controller.go11
-rw-r--r--app/controllers/sessions/controller_test.go7
-rw-r--r--app/controllers/sessions/service_test.go7
-rw-r--r--app/middleware/id_token_test.go5
-rw-r--r--pkg/web/cookie/cookie_test.go2
-rw-r--r--pkg/web/cookie/new.go24
6 files changed, 37 insertions, 19 deletions
diff --git a/app/controllers/sessions/controller.go b/app/controllers/sessions/controller.go
index 5babe7d..ae50e16 100644
--- a/app/controllers/sessions/controller.go
+++ b/app/controllers/sessions/controller.go
@@ -33,10 +33,13 @@ func (c *Controller) New(w http.ResponseWriter, r *http.Request) {
}
url, nonce := c.svc.GenerateRedirectURL()
- cookie := cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute))
// This cookie must be sent as part of a redirect that originates from the OIDC Provider
- cookie.SameSite = http.SameSiteLaxMode
- http.SetCookie(w, cookie)
+ http.SetCookie(w, cookie.New(
+ "oauth_state",
+ nonce,
+ cookie.WithSameSite(http.SameSiteLaxMode),
+ cookie.WithExpiration(time.Now().Add(10*time.Minute)),
+ ))
http.Redirect(w, r, url, http.StatusFound)
}
@@ -135,7 +138,7 @@ func (c *Controller) Create(w http.ResponseWriter, r *http.Request) {
return
}
- http.SetCookie(w, cookie.New("session", encoded, tokens.Expiry))
+ http.SetCookie(w, cookie.New("session", encoded, cookie.WithExpiration(tokens.Expiry)))
http.Redirect(w, r, "/dashboard", http.StatusFound)
}
diff --git a/app/controllers/sessions/controller_test.go b/app/controllers/sessions/controller_test.go
index 9ece4f9..c16c6cd 100644
--- a/app/controllers/sessions/controller_test.go
+++ b/app/controllers/sessions/controller_test.go
@@ -6,7 +6,6 @@ import (
"net/http"
"net/url"
"testing"
- "time"
"github.com/oauth2-proxy/mockoidc"
"github.com/stretchr/testify/assert"
@@ -93,7 +92,7 @@ func TestSessions(t *testing.T) {
r, w := test.RequestResponse(
"GET",
"/session/callback?code="+code+"&state=invalid",
- test.WithCookie(cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute))),
+ test.WithCookie(cookie.New("oauth_state", nonce)),
)
mux.ServeHTTP(w, r)
@@ -117,7 +116,7 @@ func TestSessions(t *testing.T) {
r, w := test.RequestResponse(
"GET",
"/session/callback?code="+code+"&state="+nonce,
- test.WithCookie(cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute))),
+ test.WithCookie(cookie.New("oauth_state", nonce)),
)
mux.ServeHTTP(w, r)
@@ -175,7 +174,7 @@ func TestSessions(t *testing.T) {
t.Run("POST /session/destroy", func(t *testing.T) {
t.Run("clears the session cookie", func(t *testing.T) {
- cookie := cookie.New("session", "value", time.Now().Add(5*time.Minute))
+ cookie := cookie.New("session", "value")
r, w := test.RequestResponse("POST", "/session/destroy", test.WithCookie(cookie))
mux.ServeHTTP(w, r)
diff --git a/app/controllers/sessions/service_test.go b/app/controllers/sessions/service_test.go
index f85c9be..c2de6f4 100644
--- a/app/controllers/sessions/service_test.go
+++ b/app/controllers/sessions/service_test.go
@@ -3,7 +3,6 @@ package sessions
import (
"net/http"
"testing"
- "time"
"github.com/oauth2-proxy/mockoidc"
"github.com/stretchr/testify/assert"
@@ -47,7 +46,7 @@ func TestService(t *testing.T) {
r := test.Request(
"GET",
"/session/callback?code="+code+"&state=invalid",
- test.WithCookie(cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute))),
+ test.WithCookie(cookie.New("oauth_state", nonce)),
)
tokens, err := svc.Exchange(r)
@@ -60,7 +59,7 @@ func TestService(t *testing.T) {
r := test.Request(
"GET", "/session/callback?code=invalid",
- test.WithCookie(cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute))),
+ test.WithCookie(cookie.New("oauth_state", nonce)),
)
tokens, err := svc.Exchange(r)
@@ -77,7 +76,7 @@ func TestService(t *testing.T) {
r := test.Request(
"GET",
"/session/callback?code="+code+"&state="+nonce,
- test.WithCookie(cookie.New("oauth_state", nonce, time.Now().Add(10*time.Minute))),
+ test.WithCookie(cookie.New("oauth_state", nonce)),
)
tokens, err := svc.Exchange(r)
diff --git a/app/middleware/id_token_test.go b/app/middleware/id_token_test.go
index 53ac126..02c2901 100644
--- a/app/middleware/id_token_test.go
+++ b/app/middleware/id_token_test.go
@@ -5,7 +5,6 @@ import (
"net/http"
"os"
"testing"
- "time"
"github.com/oauth2-proxy/mockoidc"
"github.com/stretchr/testify/assert"
@@ -57,7 +56,7 @@ func TestIDToken(t *testing.T) {
r, w := test.RequestResponse(
"GET",
"/example",
- test.WithCookie(cookie.New("session", encoded, time.Now().Add(1*time.Hour))),
+ test.WithCookie(cookie.New("session", encoded)),
)
server.ServeHTTP(w, r)
@@ -76,7 +75,7 @@ func TestIDToken(t *testing.T) {
r, w := test.RequestResponse(
"GET",
"/example",
- test.WithCookie(cookie.New("session", "invalid", time.Now().Add(1*time.Hour))),
+ test.WithCookie(cookie.New("session", "invalid")),
)
server.ServeHTTP(w, r)
diff --git a/pkg/web/cookie/cookie_test.go b/pkg/web/cookie/cookie_test.go
index 2f600f4..7256134 100644
--- a/pkg/web/cookie/cookie_test.go
+++ b/pkg/web/cookie/cookie_test.go
@@ -13,7 +13,7 @@ func TestCookie(t *testing.T) {
t.Run("New", func(t *testing.T) {
t.Run("returns a cookie pinned to the HOST", func(t *testing.T) {
env.With(env.Vars{"HOST": "sparkle.example.com"}, func() {
- cookie := New("name", "value", time.Now().Add(1*time.Minute))
+ cookie := New("name", "value")
assert.Equal(t, "sparkle.example.com", cookie.Domain)
assert.True(t, cookie.HttpOnly)
assert.True(t, cookie.Secure)
diff --git a/pkg/web/cookie/new.go b/pkg/web/cookie/new.go
index a3cb200..08b796a 100644
--- a/pkg/web/cookie/new.go
+++ b/pkg/web/cookie/new.go
@@ -10,12 +10,10 @@ import (
type CookieOption pls.Option[*http.Cookie]
-func New(name, value string, expires time.Time, options ...CookieOption) *http.Cookie {
+func New(name, value string, options ...CookieOption) *http.Cookie {
cookie := &http.Cookie{
Name: name,
Value: value, // TODO:: digitally sign the value
- Expires: expires,
- MaxAge: int(time.Until(expires).Seconds()),
Path: "/",
HttpOnly: true,
Secure: true,
@@ -29,3 +27,23 @@ func New(name, value string, expires time.Time, options ...CookieOption) *http.C
return cookie
}
+
+func With(with func(*http.Cookie)) CookieOption {
+ return func(c *http.Cookie) *http.Cookie {
+ with(c)
+ return c
+ }
+}
+
+func WithSameSite(value http.SameSite) CookieOption {
+ return With(func(c *http.Cookie) {
+ c.SameSite = value
+ })
+}
+
+func WithExpiration(expires time.Time) CookieOption {
+ return With(func(c *http.Cookie) {
+ c.Expires = expires
+ c.MaxAge = int(time.Until(expires).Seconds())
+ })
+}