summaryrefslogtreecommitdiff
path: root/app/middleware
diff options
context:
space:
mode:
Diffstat (limited to 'app/middleware')
-rw-r--r--app/middleware/from_cookie.go15
-rw-r--r--app/middleware/from_custom_header.go9
-rw-r--r--app/middleware/init.go2
-rw-r--r--app/middleware/raw_token.go7
-rw-r--r--app/middleware/token_parser.go3
-rw-r--r--app/middleware/user.go23
-rw-r--r--app/middleware/user_parser.go16
-rw-r--r--app/middleware/user_parser_test.go36
-rw-r--r--app/middleware/user_test.go2
9 files changed, 18 insertions, 95 deletions
diff --git a/app/middleware/from_cookie.go b/app/middleware/from_cookie.go
deleted file mode 100644
index 316d6e4..0000000
--- a/app/middleware/from_cookie.go
+++ /dev/null
@@ -1,15 +0,0 @@
-package middleware
-
-import "net/http"
-
-func FromCookie(name string) TokenParser {
- return func(r *http.Request) RawToken {
- cookies := r.CookiesNamed(name)
-
- if len(cookies) != 1 {
- return ""
- }
-
- return RawToken(cookies[0].Value)
- }
-}
diff --git a/app/middleware/from_custom_header.go b/app/middleware/from_custom_header.go
deleted file mode 100644
index f385911..0000000
--- a/app/middleware/from_custom_header.go
+++ /dev/null
@@ -1,9 +0,0 @@
-package middleware
-
-import "net/http"
-
-func FromCustomHeader(name string) TokenParser {
- return func(r *http.Request) RawToken {
- return RawToken(r.Header.Get(name))
- }
-}
diff --git a/app/middleware/init.go b/app/middleware/init.go
index 5bf84f6..23c524d 100644
--- a/app/middleware/init.go
+++ b/app/middleware/init.go
@@ -13,7 +13,7 @@ func init() {
subject := h.Get("x-jwt-claim-sub")
if x.IsPresent(subject) {
return &domain.User{
- ID: domain.ID(subject),
+ ID: domain.ID(h.Get("x-jwt-claim-sub")),
Username: h.Get("x-jwt-claim-username"),
ProfileURL: h.Get("x-jwt-claim-profile-url"),
Picture: h.Get("x-jwt-claim-picture-url"),
diff --git a/app/middleware/raw_token.go b/app/middleware/raw_token.go
deleted file mode 100644
index f7aa264..0000000
--- a/app/middleware/raw_token.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package middleware
-
-type RawToken string
-
-func (r RawToken) String() string {
- return string(r)
-}
diff --git a/app/middleware/token_parser.go b/app/middleware/token_parser.go
deleted file mode 100644
index 1a92760..0000000
--- a/app/middleware/token_parser.go
+++ /dev/null
@@ -1,3 +0,0 @@
-package middleware
-
-type TokenParser RequestParser[RawToken]
diff --git a/app/middleware/user.go b/app/middleware/user.go
index 90bf6aa..2b2dd17 100644
--- a/app/middleware/user.go
+++ b/app/middleware/user.go
@@ -3,20 +3,27 @@ package middleware
import (
"net/http"
- "github.com/xlgmokha/x/pkg/x"
+ "github.com/xlgmokha/x/pkg/log"
+ "github.com/xlgmokha/x/pkg/mapper"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/cfg"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
)
func User() func(http.Handler) http.Handler {
- parser := UserParser()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- user := parser(r)
- if x.IsPresent(user) {
- next.ServeHTTP(w, r.WithContext(cfg.CurrentUser.With(r.Context(), user)))
- } else {
- next.ServeHTTP(w, r)
- }
+ log.WithFields(r.Context(), log.Fields{
+ "payload": r.Header.Get("x-jwt-payload"),
+ "photo": r.Header.Get("x-jwt-claim-picture-url"),
+ "profile": r.Header.Get("x-jwt-claim-profile-url"),
+ "sub": r.Header.Get("x-jwt-claim-sub"),
+ "username": r.Header.Get("x-jwt-claim-username"),
+ })
+
+ next.ServeHTTP(w, r.WithContext(cfg.CurrentUser.With(
+ r.Context(),
+ mapper.MapFrom[http.Header, *domain.User](r.Header),
+ )))
})
}
}
diff --git a/app/middleware/user_parser.go b/app/middleware/user_parser.go
deleted file mode 100644
index dfa0cce..0000000
--- a/app/middleware/user_parser.go
+++ /dev/null
@@ -1,16 +0,0 @@
-package middleware
-
-import (
- "net/http"
-
- "github.com/xlgmokha/x/pkg/log"
- "github.com/xlgmokha/x/pkg/mapper"
- "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
-)
-
-func UserParser() RequestParser[*domain.User] {
- return func(r *http.Request) *domain.User {
- log.WithFields(r.Context(), log.Fields{"header": r.Header})
- return mapper.MapFrom[http.Header, *domain.User](r.Header)
- }
-}
diff --git a/app/middleware/user_parser_test.go b/app/middleware/user_parser_test.go
deleted file mode 100644
index 2127a10..0000000
--- a/app/middleware/user_parser_test.go
+++ /dev/null
@@ -1,36 +0,0 @@
-package middleware
-
-import (
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "github.com/xlgmokha/x/pkg/test"
- "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
-)
-
-func TestUserParser(t *testing.T) {
- parser := UserParser()
-
- t.Run("when x-jwt-claim-* headers are not provided", func(t *testing.T) {
- t.Run("forwards the request without a current user attached to the request", func(t *testing.T) {
- assert.Nil(t, parser(test.Request("GET", "/")))
- })
- })
-
- t.Run("when x-jwt-claim-* headers are provided", func(t *testing.T) {
- r := test.Request("GET", "/",
- test.WithRequestHeader("x-jwt-claim-sub", "1"),
- test.WithRequestHeader("x-jwt-claim-username", "root"),
- test.WithRequestHeader("x-jwt-claim-profile-url", "https://gitlab.com/tanuki"),
- test.WithRequestHeader("x-jwt-claim-picture-url", "https://example.com/profile.png"),
- )
-
- result := parser(r)
- require.NotNil(t, result)
- assert.Equal(t, domain.ID("1"), result.ID)
- assert.Equal(t, "root", result.Username)
- assert.Equal(t, "https://gitlab.com/tanuki", result.ProfileURL)
- assert.Equal(t, "https://example.com/profile.png", result.Picture)
- })
-}
diff --git a/app/middleware/user_test.go b/app/middleware/user_test.go
index c5fa7ed..66ca121 100644
--- a/app/middleware/user_test.go
+++ b/app/middleware/user_test.go
@@ -29,6 +29,8 @@ func TestUser(t *testing.T) {
t.Run("when x-jwt-claim-* headers are provided", func(t *testing.T) {
server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.True(t, IsLoggedIn(r))
+
user := cfg.CurrentUser.From(r.Context())
require.NotNil(t, user)