From ab8075f02f50d8bd0be3c23b87e63f10828528ed Mon Sep 17 00:00:00 2001 From: mo khan Date: Mon, 21 Apr 2025 12:31:39 -0600 Subject: refactor: convert Repository to Repository[T Entity] --- pkg/db/repository.go | 41 ++++++++++++++++++++++------------------- pkg/db/repository_test.go | 4 ++-- 2 files changed, 24 insertions(+), 21 deletions(-) (limited to 'pkg/db') diff --git a/pkg/db/repository.go b/pkg/db/repository.go index ef4b9fb..79c7ae3 100644 --- a/pkg/db/repository.go +++ b/pkg/db/repository.go @@ -1,38 +1,41 @@ package db -import ( - "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/domain" - "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/pls" -) - -type Repository interface { - All() []*domain.Sparkle - Save(*domain.Sparkle) error +import "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/pls" + +type Entity interface { + GetID() string + SetID(id string) error + Validate() error +} + +type Repository[T Entity] interface { + All() []T + Save(T) error } -type inMemoryRepository struct { - sparkles []*domain.Sparkle +type inMemoryRepository[T Entity] struct { + items []T } -func NewRepository() Repository { - return &inMemoryRepository{ - sparkles: []*domain.Sparkle{}, +func NewRepository[T Entity]() Repository[T] { + return &inMemoryRepository[T]{ + items: []T{}, } } -func (r *inMemoryRepository) All() []*domain.Sparkle { - return r.sparkles +func (r *inMemoryRepository[T]) All() []T { + return r.items } -func (r *inMemoryRepository) Save(item *domain.Sparkle) error { +func (r *inMemoryRepository[T]) Save(item T) error { if err := item.Validate(); err != nil { return err } - if item.ID == "" { - item.ID = pls.GenerateULID() + if item.GetID() == "" { + item.SetID(pls.GenerateULID()) } - r.sparkles = append(r.sparkles, item) + r.items = append(r.items, item) return nil } diff --git a/pkg/db/repository_test.go b/pkg/db/repository_test.go index 57aee13..bb788d2 100644 --- a/pkg/db/repository_test.go +++ b/pkg/db/repository_test.go @@ -9,13 +9,13 @@ import ( ) func TestRepository(t *testing.T) { - storage := NewRepository() + storage := NewRepository[*domain.Sparkle]() t.Run("Save", func(t *testing.T) { t.Run("an invalid Sparkle", func(t *testing.T) { err := storage.Save(&domain.Sparkle{Reason: "because"}) - assert.NotNil(t, err) + assert.Error(t, err) assert.Equal(t, 0, len(storage.All())) }) -- cgit v1.2.3