From 1ece3b42051d26050cd612a3ed9a20122d501746 Mon Sep 17 00:00:00 2001 From: mo khan Date: Mon, 21 Apr 2025 13:06:56 -0600 Subject: feat: attach current user if they are in the db --- pkg/db/entity.go | 7 ++++++ pkg/db/in_memory_repository.go | 39 +++++++++++++++++++++++++++++ pkg/db/in_memory_repository_test.go | 50 +++++++++++++++++++++++++++++++++++++ pkg/db/repository.go | 36 +------------------------- pkg/db/repository_test.go | 33 ------------------------ pkg/domain/user.go | 13 ++++++++++ pkg/web/middleware/user.go | 24 +++++++++++++++--- pkg/web/middleware/user_test.go | 24 ++++++++++++++++-- 8 files changed, 153 insertions(+), 73 deletions(-) create mode 100644 pkg/db/entity.go create mode 100644 pkg/db/in_memory_repository.go create mode 100644 pkg/db/in_memory_repository_test.go delete mode 100644 pkg/db/repository_test.go (limited to 'pkg') diff --git a/pkg/db/entity.go b/pkg/db/entity.go new file mode 100644 index 0000000..1dcf4c3 --- /dev/null +++ b/pkg/db/entity.go @@ -0,0 +1,7 @@ +package db + +type Entity interface { + GetID() string + SetID(id string) error + Validate() error +} diff --git a/pkg/db/in_memory_repository.go b/pkg/db/in_memory_repository.go new file mode 100644 index 0000000..5859c0b --- /dev/null +++ b/pkg/db/in_memory_repository.go @@ -0,0 +1,39 @@ +package db + +import ( + "github.com/xlgmokha/x/pkg/x" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/pls" +) + +type inMemoryRepository[T Entity] struct { + items []T +} + +func NewRepository[T Entity]() Repository[T] { + return &inMemoryRepository[T]{ + items: []T{}, + } +} + +func (r *inMemoryRepository[T]) All() []T { + return r.items +} + +func (r *inMemoryRepository[T]) Find(id string) T { + return x.Find(r.All(), func(item T) bool { + return item.GetID() == id + }) +} + +func (r *inMemoryRepository[T]) Save(item T) error { + if err := item.Validate(); err != nil { + return err + } + + if item.GetID() == "" { + item.SetID(pls.GenerateULID()) + } + + r.items = append(r.items, item) + return nil +} diff --git a/pkg/db/in_memory_repository_test.go b/pkg/db/in_memory_repository_test.go new file mode 100644 index 0000000..382a656 --- /dev/null +++ b/pkg/db/in_memory_repository_test.go @@ -0,0 +1,50 @@ +package db + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/domain" +) + +func TestInMemoryRepository(t *testing.T) { + storage := NewRepository[*domain.Sparkle]() + + t.Run("Save", func(t *testing.T) { + t.Run("an invalid Sparkle", func(t *testing.T) { + err := storage.Save(&domain.Sparkle{Reason: "because"}) + + assert.Error(t, err) + assert.Equal(t, 0, len(storage.All())) + }) + + t.Run("a valid Sparkle", func(t *testing.T) { + sparkle := &domain.Sparkle{Sparklee: "@tanuki", Reason: "because"} + require.NoError(t, storage.Save(sparkle)) + + sparkles := storage.All() + assert.Equal(t, 1, len(sparkles)) + assert.NotEmpty(t, sparkles[0].ID) + assert.Equal(t, "@tanuki", sparkles[0].Sparklee) + assert.Equal(t, "because", sparkles[0].Reason) + }) + }) + + t.Run("Find", func(t *testing.T) { + t.Run("when the entity exists", func(t *testing.T) { + sparkle, err := domain.NewSparkle("@tanuki for testing this func") + require.NoError(t, err) + require.NoError(t, storage.Save(sparkle)) + + result := storage.Find(sparkle.ID) + require.NotNil(t, result) + require.Equal(t, sparkle, result) + }) + + t.Run("when the entity does not exist", func(t *testing.T) { + result := storage.Find("unknown") + require.Nil(t, result) + }) + }) +} diff --git a/pkg/db/repository.go b/pkg/db/repository.go index 79c7ae3..397eee7 100644 --- a/pkg/db/repository.go +++ b/pkg/db/repository.go @@ -1,41 +1,7 @@ package db -import "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/pls" - -type Entity interface { - GetID() string - SetID(id string) error - Validate() error -} - type Repository[T Entity] interface { All() []T + Find(string) T Save(T) error } - -type inMemoryRepository[T Entity] struct { - items []T -} - -func NewRepository[T Entity]() Repository[T] { - return &inMemoryRepository[T]{ - items: []T{}, - } -} - -func (r *inMemoryRepository[T]) All() []T { - return r.items -} - -func (r *inMemoryRepository[T]) Save(item T) error { - if err := item.Validate(); err != nil { - return err - } - - if item.GetID() == "" { - item.SetID(pls.GenerateULID()) - } - - r.items = append(r.items, item) - return nil -} diff --git a/pkg/db/repository_test.go b/pkg/db/repository_test.go deleted file mode 100644 index bb788d2..0000000 --- a/pkg/db/repository_test.go +++ /dev/null @@ -1,33 +0,0 @@ -package db - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/domain" -) - -func TestRepository(t *testing.T) { - storage := NewRepository[*domain.Sparkle]() - - t.Run("Save", func(t *testing.T) { - t.Run("an invalid Sparkle", func(t *testing.T) { - err := storage.Save(&domain.Sparkle{Reason: "because"}) - - assert.Error(t, err) - assert.Equal(t, 0, len(storage.All())) - }) - - t.Run("a valid Sparkle", func(t *testing.T) { - sparkle := &domain.Sparkle{Sparklee: "@tanuki", Reason: "because"} - require.NoError(t, storage.Save(sparkle)) - - sparkles := storage.All() - assert.Equal(t, 1, len(sparkles)) - assert.NotEmpty(t, sparkles[0].ID) - assert.Equal(t, "@tanuki", sparkles[0].Sparklee) - assert.Equal(t, "because", sparkles[0].Reason) - }) - }) -} diff --git a/pkg/domain/user.go b/pkg/domain/user.go index ed06dc1..4053e8f 100644 --- a/pkg/domain/user.go +++ b/pkg/domain/user.go @@ -7,3 +7,16 @@ type User struct { func NewUser() *User { return &User{} } + +func (s *User) GetID() string { + return s.ID +} + +func (s *User) SetID(id string) error { + s.ID = id + return nil +} + +func (s *User) Validate() error { + return nil +} diff --git a/pkg/web/middleware/user.go b/pkg/web/middleware/user.go index 9dc1a1f..b01ae48 100644 --- a/pkg/web/middleware/user.go +++ b/pkg/web/middleware/user.go @@ -1,11 +1,29 @@ package middleware -import "net/http" +import ( + "net/http" -func User() func(http.Handler) http.Handler { + "github.com/xlgmokha/x/pkg/x" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/db" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/domain" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key" +) + +func User(db db.Repository[*domain.User]) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - next.ServeHTTP(w, r) + idToken := key.IDToken.From(r.Context()) + if x.IsZero(idToken) { + next.ServeHTTP(w, r) + return + } + + user := db.Find(idToken.Subject) + if x.IsZero(user) { + next.ServeHTTP(w, r) + } else { + next.ServeHTTP(w, r.WithContext(key.CurrentUser.With(r.Context(), user))) + } }) } } diff --git a/pkg/web/middleware/user_test.go b/pkg/web/middleware/user_test.go index 7119b41..cde7dec 100644 --- a/pkg/web/middleware/user_test.go +++ b/pkg/web/middleware/user_test.go @@ -6,16 +6,36 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/db" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/domain" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc" "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/test" ) func TestUser(t *testing.T) { - middleware := User() + repository := db.NewRepository[*domain.User]() + middleware := User(repository) + + knownUser := &domain.User{ID: "1"} + require.NoError(t, repository.Save(knownUser)) t.Run("when an ID Token is found in the context", func(t *testing.T) { - t.Run("When the user is found in the db", func(t *testing.T) { + t.Run("when the user is found in the db", func(t *testing.T) { + server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user := key.CurrentUser.From(r.Context()) + require.NotNil(t, user) + + w.WriteHeader(http.StatusTeapot) + })) + + idToken := &oidc.IDToken{Subject: knownUser.ID} + ctx := key.IDToken.With(t.Context(), idToken) + + r, w := test.RequestResponse("GET", "/example", test.WithContext(ctx)) + server.ServeHTTP(w, r) + assert.Equal(t, http.StatusTeapot, w.Code) }) t.Run("when the user is not found in the db", func(t *testing.T) { -- cgit v1.2.3