package middleware

import (
	"context"
	"net/http"
	"os"
	"testing"

	"github.com/oauth2-proxy/mockoidc"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/xlgmokha/x/pkg/log"
	"github.com/xlgmokha/x/pkg/test"
	"github.com/xlgmokha/x/pkg/x"
	xcfg "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/cfg"
	"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc"
	"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/web"
	"golang.org/x/oauth2"
)

// TestIDToken checks the handler chain produced by
// IDToken(..., IDTokenFromSessionCookie) against three kinds of request: one
// carrying a session cookie built from freshly issued tokens, one carrying an
// unparsable session cookie, and one carrying no cookie at all. In every case
// the downstream handler must run (it answers 418 so the test can tell); only
// the first case expects an ID token on the request context.
func TestIDToken(t *testing.T) {
	srv := oidc.NewTestServer(t)
	defer srv.Close()

	// Place an HTTP client whose transport logs to stdout under the
	// oauth2.HTTPClient context key before constructing the OIDC provider.
	httpClient := &http.Client{
		Transport: &web.Transport{Logger: log.New(os.Stdout, log.Fields{})},
	}
	oidcConfig := srv.MockOIDC.Config()
	ctx := context.WithValue(t.Context(), oauth2.HTTPClient, httpClient)

	provider, err := oidc.New(
		ctx,
		srv.Issuer(),
		oidcConfig.ClientID,
		oidcConfig.ClientSecret,
		"https://example.com/oauth/callback",
	)
	require.NoError(t, err)

	withIDToken := IDToken(provider, IDTokenFromSessionCookie)

	// expectNoToken builds a downstream handler that requires the request
	// context to hold no ID token and then replies with 418.
	expectNoToken := func(t *testing.T) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			require.Nil(t, xcfg.IDToken.From(r.Context()))
			w.WriteHeader(http.StatusTeapot)
		}
	}

	t.Run("when an active session cookie is provided", func(t *testing.T) {
		t.Run("attaches the token to the request context", func(t *testing.T) {
			user := mockoidc.DefaultUser()
			issued, rawID := srv.CreateTokensFor(user)
			session := &oidc.Tokens{Token: issued, IDToken: oidc.RawToken(rawID)}
			cookieValue := x.Must(session.ToBase64String())

			// The downstream handler verifies that the token found on the
			// context belongs to the user the tokens were issued for.
			downstream := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				idToken := xcfg.IDToken.From(r.Context())
				require.NotNil(t, idToken)
				assert.Equal(t, user.Subject, idToken.Subject)

				w.WriteHeader(http.StatusTeapot)
			})
			handler := withIDToken(downstream)

			cookie := web.NewCookie("session", cookieValue)
			req, rec := test.RequestResponse("GET", "/example", test.WithCookie(cookie))

			handler.ServeHTTP(rec, req)

			assert.Equal(t, http.StatusTeapot, rec.Code)
		})
	})

	t.Run("when an invalid session cookie is provided", func(t *testing.T) {
		t.Run("forwards the request", func(t *testing.T) {
			handler := withIDToken(expectNoToken(t))

			cookie := web.NewCookie("session", "invalid")
			req, rec := test.RequestResponse("GET", "/example", test.WithCookie(cookie))

			handler.ServeHTTP(rec, req)

			assert.Equal(t, http.StatusTeapot, rec.Code)
		})
	})

	t.Run("when no cookies are provided", func(t *testing.T) {
		t.Run("forwards the request", func(t *testing.T) {
			handler := withIDToken(expectNoToken(t))

			req, rec := test.RequestResponse("GET", "/example")

			handler.ServeHTTP(rec, req)

			assert.Equal(t, http.StatusTeapot, rec.Code)
		})
	})
}