summaryrefslogtreecommitdiff
path: root/app/middleware
diff options
context:
space:
mode:
Diffstat (limited to 'app/middleware')
-rw-r--r--app/middleware/init.go15
-rw-r--r--app/middleware/is_logged_in.go3
-rw-r--r--app/middleware/request_parser.go9
-rw-r--r--app/middleware/require_user_test.go3
-rw-r--r--app/middleware/user.go11
-rw-r--r--app/middleware/user_test.go14
6 files changed, 22 insertions, 33 deletions
diff --git a/app/middleware/init.go b/app/middleware/init.go
index 23c524d..4ff10c4 100644
--- a/app/middleware/init.go
+++ b/app/middleware/init.go
@@ -4,21 +4,16 @@ import (
"net/http"
"github.com/xlgmokha/x/pkg/mapper"
- "github.com/xlgmokha/x/pkg/x"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
)
func init() {
mapper.Register(func(h http.Header) *domain.User {
- subject := h.Get("x-jwt-claim-sub")
- if x.IsPresent(subject) {
- return &domain.User{
- ID: domain.ID(h.Get("x-jwt-claim-sub")),
- Username: h.Get("x-jwt-claim-username"),
- ProfileURL: h.Get("x-jwt-claim-profile-url"),
- Picture: h.Get("x-jwt-claim-picture-url"),
- }
+ return &domain.User{
+ ID: domain.ID(h.Get("x-id-jwt-claim-sub")),
+ Username: h.Get("x-id-jwt-claim-username"),
+ ProfileURL: h.Get("x-id-jwt-claim-profile-url"),
+ Picture: h.Get("x-id-jwt-claim-picture-url"),
}
- return nil
})
}
diff --git a/app/middleware/is_logged_in.go b/app/middleware/is_logged_in.go
index e2f0445..f70a03b 100644
--- a/app/middleware/is_logged_in.go
+++ b/app/middleware/is_logged_in.go
@@ -8,5 +8,6 @@ import (
)
var IsLoggedIn x.Predicate[*http.Request] = x.Predicate[*http.Request](func(r *http.Request) bool {
- return x.IsPresent(cfg.CurrentUser.From(r.Context()))
+ user := cfg.CurrentUser.From(r.Context())
+ return x.IsPresent(user) && x.IsPresent(user.ID)
})
diff --git a/app/middleware/request_parser.go b/app/middleware/request_parser.go
deleted file mode 100644
index dfc5d3a..0000000
--- a/app/middleware/request_parser.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package middleware
-
-import (
- "net/http"
-
- "github.com/xlgmokha/x/pkg/x"
-)
-
-type RequestParser[T any] x.Mapper[*http.Request, T]
diff --git a/app/middleware/require_user_test.go b/app/middleware/require_user_test.go
index 07cbf92..20b5f94 100644
--- a/app/middleware/require_user_test.go
+++ b/app/middleware/require_user_test.go
@@ -28,7 +28,8 @@ func TestRequireUser(t *testing.T) {
t.Run("when a user is logged in", func(t *testing.T) {
t.Run("forwards the request", func(t *testing.T) {
- r, w := test.RequestResponse("GET", "/example", test.WithContextKeyValue(t.Context(), cfg.CurrentUser, &domain.User{}))
+ user := &domain.User{ID: domain.ID("1")}
+ r, w := test.RequestResponse("GET", "/example", test.WithContextKeyValue(t.Context(), cfg.CurrentUser, user))
server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusTeapot)
diff --git a/app/middleware/user.go b/app/middleware/user.go
index 2b2dd17..317671e 100644
--- a/app/middleware/user.go
+++ b/app/middleware/user.go
@@ -13,11 +13,12 @@ func User() func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.WithFields(r.Context(), log.Fields{
- "payload": r.Header.Get("x-jwt-payload"),
- "photo": r.Header.Get("x-jwt-claim-picture-url"),
- "profile": r.Header.Get("x-jwt-claim-profile-url"),
- "sub": r.Header.Get("x-jwt-claim-sub"),
- "username": r.Header.Get("x-jwt-claim-username"),
+ "authorization": r.Header.Get("Authorization"),
+ "payload": r.Header.Get("x-id-jwt-payload"),
+ "photo": r.Header.Get("x-id-jwt-claim-picture-url"),
+ "profile": r.Header.Get("x-id-jwt-claim-profile-url"),
+ "sub": r.Header.Get("x-id-jwt-claim-sub"),
+ "username": r.Header.Get("x-id-jwt-claim-username"),
})
next.ServeHTTP(w, r.WithContext(cfg.CurrentUser.With(
diff --git a/app/middleware/user_test.go b/app/middleware/user_test.go
index 66ca121..c778c98 100644
--- a/app/middleware/user_test.go
+++ b/app/middleware/user_test.go
@@ -14,9 +14,9 @@ import (
func TestUser(t *testing.T) {
middleware := User()
- t.Run("when x-jwt-claim-* headers are not provided", func(t *testing.T) {
+ t.Run("when x-id-jwt-claim-* headers are not provided", func(t *testing.T) {
server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- require.Nil(t, cfg.CurrentUser.From(r.Context()))
+ require.False(t, IsLoggedIn(r))
w.WriteHeader(http.StatusTeapot)
}))
@@ -27,7 +27,7 @@ func TestUser(t *testing.T) {
assert.Equal(t, http.StatusTeapot, w.Code)
})
- t.Run("when x-jwt-claim-* headers are provided", func(t *testing.T) {
+ t.Run("when x-id-jwt-claim-* headers are provided", func(t *testing.T) {
server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.True(t, IsLoggedIn(r))
@@ -43,10 +43,10 @@ func TestUser(t *testing.T) {
}))
r, w := test.RequestResponse("GET", "/",
- test.WithRequestHeader("x-jwt-claim-sub", "1"),
- test.WithRequestHeader("x-jwt-claim-username", "root"),
- test.WithRequestHeader("x-jwt-claim-profile-url", "https://gitlab.com/tanuki"),
- test.WithRequestHeader("x-jwt-claim-picture-url", "https://example.com/profile.png"),
+ test.WithRequestHeader("x-id-jwt-claim-sub", "1"),
+ test.WithRequestHeader("x-id-jwt-claim-username", "root"),
+ test.WithRequestHeader("x-id-jwt-claim-profile-url", "https://gitlab.com/tanuki"),
+ test.WithRequestHeader("x-id-jwt-claim-picture-url", "https://example.com/profile.png"),
)
server.ServeHTTP(w, r)