diff options
Diffstat (limited to 'pkg/web')
| -rw-r--r-- | pkg/web/middleware/unpack_token.go | 8 | ||||
| -rw-r--r-- | pkg/web/middleware/unpack_token_test.go | 2 |
2 files changed, 5 insertions, 5 deletions
diff --git a/pkg/web/middleware/unpack_token.go b/pkg/web/middleware/unpack_token.go index d31f9cc..0b182a0 100644 --- a/pkg/web/middleware/unpack_token.go +++ b/pkg/web/middleware/unpack_token.go @@ -9,9 +9,9 @@ import ( "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc" ) -type TokenParser func(*http.Request) string +type TokenParser func(*http.Request) oidc.RawIDToken -func fromSessionCookie(r *http.Request) string { +func FromSessionCookie(r *http.Request) oidc.RawIDToken { cookies := r.CookiesNamed("session") if len(cookies) != 1 { @@ -28,7 +28,7 @@ func fromSessionCookie(r *http.Request) string { } func UnpackToken(cfg *oidc.OpenID) func(http.Handler) http.Handler { - parsers := []TokenParser{fromSessionCookie} + parsers := []TokenParser{FromSessionCookie} return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -36,7 +36,7 @@ func UnpackToken(cfg *oidc.OpenID) func(http.Handler) http.Handler { rawIDToken := parser(r) if !x.IsZero(rawIDToken) { verifier := cfg.Provider.VerifierContext(r.Context(), cfg.OIDCConfig) - idToken, err := verifier.Verify(r.Context(), rawIDToken) + idToken, err := verifier.Verify(r.Context(), rawIDToken.String()) if err != nil { log.WithFields(r.Context(), log.Fields{"error": err}) } else { diff --git a/pkg/web/middleware/unpack_token_test.go b/pkg/web/middleware/unpack_token_test.go index f2250bc..116e88f 100644 --- a/pkg/web/middleware/unpack_token_test.go +++ b/pkg/web/middleware/unpack_token_test.go @@ -43,7 +43,7 @@ func TestUnpackToken(t *testing.T) { user := mockoidc.DefaultUser() token, rawIDToken := srv.CreateTokensFor(user) - tokens := &oidc.Tokens{Token: token, IDToken: rawIDToken} + tokens := &oidc.Tokens{Token: token, IDToken: oidc.RawIDToken(rawIDToken)} encoded := x.Must(tokens.ToBase64String()) server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
