diff options
| -rw-r--r-- | pkg/test/http.go | 5 | ||||
| -rw-r--r-- | pkg/web/middleware/require_user.go | 22 | ||||
| -rw-r--r-- | pkg/web/middleware/require_user_test.go | 43 |
3 files changed, 70 insertions, 0 deletions
diff --git a/pkg/test/http.go b/pkg/test/http.go index 54712f1..280aef6 100644 --- a/pkg/test/http.go +++ b/pkg/test/http.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" + xcontext "github.com/xlgmokha/x/pkg/context" "github.com/xlgmokha/x/pkg/serde" "github.com/xlgmokha/x/pkg/x" ) @@ -55,6 +56,10 @@ func WithContext(ctx context.Context) RequestOption { } } +func WithContextKeyValue[T any](ctx context.Context, key xcontext.Key[T], item T) RequestOption { + return WithContext(key.With(ctx, item)) +} + func WithCookie(cookie *http.Cookie) RequestOption { return func(r *http.Request) *http.Request { r.AddCookie(cookie) diff --git a/pkg/web/middleware/require_user.go b/pkg/web/middleware/require_user.go new file mode 100644 index 0000000..e81d5b5 --- /dev/null +++ b/pkg/web/middleware/require_user.go @@ -0,0 +1,22 @@ +package middleware + +import ( + "net/http" + + "github.com/xlgmokha/x/pkg/x" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key" +) + +func RequireUser(code int, url string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := key.CurrentUser.From(r.Context()) + if x.IsZero(user) { + http.Redirect(w, r, url, code) + return + } + + next.ServeHTTP(w, r) + }) + } +} diff --git a/pkg/web/middleware/require_user_test.go b/pkg/web/middleware/require_user_test.go new file mode 100644 index 0000000..ac764f6 --- /dev/null +++ b/pkg/web/middleware/require_user_test.go @@ -0,0 +1,43 @@ +package middleware + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/domain" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/test" +) + +func TestRequireUser(t *testing.T) { + middleware := RequireUser(http.StatusFound, "/login") + + t.Run("when a user is not logged in", func(t *testing.T) { + t.Run("redirects to the homepage", func(t *testing.T) { + r, w := test.RequestResponse("GET", "/example") + + server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Fail(t, "unexpected call to handler") + })) + server.ServeHTTP(w, r) + + require.Equal(t, http.StatusFound, w.Code) + assert.Equal(t, "/login", w.Header().Get("Location")) + }) + }) + + t.Run("when a user is logged in", func(t *testing.T) { + t.Run("forwards the request", func(t *testing.T) { + r, w := test.RequestResponse("GET", "/example", test.WithContextKeyValue(t.Context(), key.CurrentUser, &domain.User{})) + + server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + })) + server.ServeHTTP(w, r) + + require.Equal(t, http.StatusTeapot, w.Code) + }) + }) +} |
