diff options
| author | mo khan <mo@mokhan.ca> | 2025-04-21 13:15:39 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-04-21 13:15:39 -0600 |
| commit | bbd583211275e686f1292a93c828fde64fdda8ed (patch) | |
| tree | 787ff49cfb08851a80fe41e6291393d644957252 /pkg/web/middleware | |
| parent | 1ece3b42051d26050cd612a3ed9a20122d501746 (diff) | |
feat: provision new users on-demand
Diffstat (limited to 'pkg/web/middleware')
| -rw-r--r-- | pkg/web/middleware/user.go | 12 | ||||
| -rw-r--r-- | pkg/web/middleware/user_test.go | 31 |
2 files changed, 33 insertions, 10 deletions
diff --git a/pkg/web/middleware/user.go b/pkg/web/middleware/user.go index b01ae48..ab4e3cc 100644 --- a/pkg/web/middleware/user.go +++ b/pkg/web/middleware/user.go @@ -3,6 +3,7 @@ package middleware import ( "net/http" + "github.com/xlgmokha/x/pkg/log" "github.com/xlgmokha/x/pkg/x" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/db" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/domain" @@ -20,10 +21,15 @@ func User(db db.Repository[*domain.User]) func(http.Handler) http.Handler { user := db.Find(idToken.Subject) if x.IsZero(user) { - next.ServeHTTP(w, r) - } else { - next.ServeHTTP(w, r.WithContext(key.CurrentUser.With(r.Context(), user))) + user = &domain.User{ID: idToken.Subject} + if err := db.Save(user); err != nil { + log.WithFields(r.Context(), log.Fields{"error": err}) + next.ServeHTTP(w, r) + return + } } + + next.ServeHTTP(w, r.WithContext(key.CurrentUser.With(r.Context(), user))) }) } } diff --git a/pkg/web/middleware/user_test.go b/pkg/web/middleware/user_test.go index cde7dec..447d54d 100644 --- a/pkg/web/middleware/user_test.go +++ b/pkg/web/middleware/user_test.go @@ -10,6 +10,7 @@ import ( "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/domain" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/pls" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/test" ) @@ -17,20 +18,20 @@ func TestUser(t *testing.T) { repository := db.NewRepository[*domain.User]() middleware := User(repository) - knownUser := &domain.User{ID: "1"} + knownUser := &domain.User{ID: pls.GenerateULID()} require.NoError(t, repository.Save(knownUser)) - t.Run("when an ID Token is found in the context", func(t *testing.T) { - t.Run("when the user is found in the db", func(t *testing.T) { + t.Run("when ID Token is provided", func(t *testing.T) { + t.Run("when user is known", func(t *testing.T) { server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := key.CurrentUser.From(r.Context()) require.NotNil(t, user) + assert.Equal(t, knownUser.ID, user.ID) w.WriteHeader(http.StatusTeapot) })) - idToken := &oidc.IDToken{Subject: knownUser.ID} - ctx := key.IDToken.With(t.Context(), idToken) + ctx := key.IDToken.With(t.Context(), &oidc.IDToken{Subject: knownUser.ID}) r, w := test.RequestResponse("GET", "/example", test.WithContext(ctx)) server.ServeHTTP(w, r) @@ -38,12 +39,28 @@ func TestUser(t *testing.T) { assert.Equal(t, http.StatusTeapot, w.Code) }) - t.Run("when the user is not found in the db", func(t *testing.T) { + t.Run("when user is unknown", func(t *testing.T) { + unknownID := pls.GenerateULID() + server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := key.CurrentUser.From(r.Context()) + require.NotNil(t, user) + assert.Equal(t, unknownID, user.ID) + + w.WriteHeader(http.StatusTeapot) + })) + + ctx := key.IDToken.With(t.Context(), &oidc.IDToken{Subject: unknownID}) + + r, w := test.RequestResponse("GET", "/example", test.WithContext(ctx)) + server.ServeHTTP(w, r) + + assert.Equal(t, http.StatusTeapot, w.Code) + require.NotNil(t, repository.Find(unknownID)) }) }) - t.Run("when an ID Token is not found in the context", func(t *testing.T) { + t.Run("when ID Token is not provided", func(t *testing.T) { server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := key.CurrentUser.From(r.Context()) require.Nil(t, user) |
