package middleware

import (
	"net/http"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
	"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key"
	"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/test"
)

// TestRequireUser exercises the handler wrapper returned by
// RequireUser(http.StatusFound, "/login") in two situations: a request built
// without any context options, and a request whose context carries a
// *domain.User under key.CurrentUser.
func TestRequireUser(t *testing.T) {
	middleware := RequireUser(http.StatusFound, "/login")

	t.Run("when a user is not logged in", func(t *testing.T) {
		t.Run("redirects to the login page", func(t *testing.T) {
			// NOTE(review): w is presumably an *httptest.ResponseRecorder
			// (it exposes Code and Header()) — confirm in pkg/test.
			r, w := test.RequestResponse(http.MethodGet, "/example")

			// The wrapped handler must not run for this request; reaching it
			// fails the test immediately. ServeHTTP is invoked synchronously
			// below, so require's FailNow runs on the test goroutine.
			server := middleware(http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
				require.Fail(t, "unexpected call to handler")
			}))

			server.ServeHTTP(w, r)

			require.Equal(t, http.StatusFound, w.Code)
			assert.Equal(t, "/login", w.Header().Get("Location"))
		})
	})

	t.Run("when a user is logged in", func(t *testing.T) {
		t.Run("forwards the request", func(t *testing.T) {
			r, w := test.RequestResponse(
				http.MethodGet,
				"/example",
				test.WithContextKeyValue(t.Context(), key.CurrentUser, &domain.User{}),
			)

			// A distinctive status code proves the response came from the
			// wrapped handler rather than from the middleware itself.
			server := middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
				rw.WriteHeader(http.StatusTeapot)
			}))

			server.ServeHTTP(w, r)

			require.Equal(t, http.StatusTeapot, w.Code)
		})
	})
}