summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--app/controllers/dashboard/controller.go2
-rw-r--r--app/controllers/dashboard/controller_test.go4
-rw-r--r--app/controllers/sessions/controller.go2
-rw-r--r--app/controllers/sparkles/controller.go2
-rw-r--r--app/middleware/require_user.go15
-rw-r--r--app/middleware/require_user_test.go6
6 files changed, 12 insertions, 19 deletions
diff --git a/app/controllers/dashboard/controller.go b/app/controllers/dashboard/controller.go
index ef5b18d..220871f 100644
--- a/app/controllers/dashboard/controller.go
+++ b/app/controllers/dashboard/controller.go
@@ -17,7 +17,7 @@ func New() *Controller {
}
func (c *Controller) MountTo(mux *http.ServeMux) {
- requireUser := middleware.RequireUser(http.StatusFound, "/")
+ requireUser := middleware.RequireUser()
mux.Handle("GET /dashboard", requireUser(http.HandlerFunc(c.Show)))
}
diff --git a/app/controllers/dashboard/controller_test.go b/app/controllers/dashboard/controller_test.go
index ced3fd5..629a03a 100644
--- a/app/controllers/dashboard/controller_test.go
+++ b/app/controllers/dashboard/controller_test.go
@@ -22,9 +22,7 @@ func TestController(t *testing.T) {
mux.ServeHTTP(w, r)
- assert.Equal(t, http.StatusFound, w.Code)
- location := w.HeaderMap.Get("Location")
- assert.Equal(t, "/", location)
+ assert.Equal(t, http.StatusNotFound, w.Code)
})
})
diff --git a/app/controllers/sessions/controller.go b/app/controllers/sessions/controller.go
index 08002a2..9a65ae3 100644
--- a/app/controllers/sessions/controller.go
+++ b/app/controllers/sessions/controller.go
@@ -136,5 +136,5 @@ func (c *Controller) Create(w http.ResponseWriter, r *http.Request) {
}
http.SetCookie(w, cookie.New("session", encoded, tokens.Expiry))
- http.Redirect(w, r, "/dashboard", http.StatusFound)
+ http.Redirect(w, r, "/", http.StatusFound)
}
diff --git a/app/controllers/sparkles/controller.go b/app/controllers/sparkles/controller.go
index 9c319b2..e0da8c4 100644
--- a/app/controllers/sparkles/controller.go
+++ b/app/controllers/sparkles/controller.go
@@ -7,8 +7,10 @@ import (
"github.com/xlgmokha/x/pkg/mapper"
"github.com/xlgmokha/x/pkg/serde"
"github.com/xlgmokha/x/pkg/x"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/cfg"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/middleware"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/views"
)
type Controller struct {
diff --git a/app/middleware/require_user.go b/app/middleware/require_user.go
index 8df4fd7..d0d5355 100644
--- a/app/middleware/require_user.go
+++ b/app/middleware/require_user.go
@@ -2,21 +2,16 @@ package middleware
import (
"net/http"
-
- "github.com/xlgmokha/x/pkg/x"
- "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/cfg"
)
-func RequireUser(code int, url string) func(http.Handler) http.Handler {
+func RequireUser() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- user := cfg.CurrentUser.From(r.Context())
- if x.IsZero(user) {
- http.Redirect(w, r, url, code)
- return
+ if IsLoggedIn(r) {
+ next.ServeHTTP(w, r)
+ } else {
+ w.WriteHeader(http.StatusNotFound)
}
-
- next.ServeHTTP(w, r)
})
}
}
diff --git a/app/middleware/require_user_test.go b/app/middleware/require_user_test.go
index 17c0276..48afff7 100644
--- a/app/middleware/require_user_test.go
+++ b/app/middleware/require_user_test.go
@@ -4,7 +4,6 @@ import (
"net/http"
"testing"
- "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/cfg"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
@@ -12,7 +11,7 @@ import (
)
func TestRequireUser(t *testing.T) {
- middleware := RequireUser(http.StatusFound, "/login")
+ middleware := RequireUser()
t.Run("when a user is not logged in", func(t *testing.T) {
t.Run("redirects to the homepage", func(t *testing.T) {
@@ -23,8 +22,7 @@ func TestRequireUser(t *testing.T) {
}))
server.ServeHTTP(w, r)
- require.Equal(t, http.StatusFound, w.Code)
- assert.Equal(t, "/login", w.Header().Get("Location"))
+ require.Equal(t, http.StatusNotFound, w.Code)
})
})