From c82468b1b32ad5bfb347fe65cd5dcfb6680795d1 Mon Sep 17 00:00:00 2001 From: mo khan Date: Wed, 7 May 2025 10:47:31 -0700 Subject: refactor: provide context to repository to apply timeout --- app/db/in_memory_repository.go | 9 +++++---- app/db/in_memory_repository_test.go | 14 +++++++------- 2 files changed, 12 insertions(+), 11 deletions(-) (limited to 'app/db') diff --git a/app/db/in_memory_repository.go b/app/db/in_memory_repository.go index 5d8628d..ba9ebad 100644 --- a/app/db/in_memory_repository.go +++ b/app/db/in_memory_repository.go @@ -1,6 +1,7 @@ package db import ( + "context" "sort" "github.com/xlgmokha/x/pkg/x" @@ -18,17 +19,17 @@ func NewRepository[T domain.Entity]() domain.Repository[T] { } } -func (r *inMemoryRepository[T]) All() []T { +func (r *inMemoryRepository[T]) All(ctx context.Context) []T { return r.items } -func (r *inMemoryRepository[T]) Find(id domain.ID) T { - return x.Find(r.All(), func(item T) bool { +func (r *inMemoryRepository[T]) Find(ctx context.Context, id domain.ID) T { + return x.Find(r.All(ctx), func(item T) bool { return item.GetID() == id }) } -func (r *inMemoryRepository[T]) Save(item T) error { +func (r *inMemoryRepository[T]) Save(ctx context.Context, item T) error { if err := item.Validate(); err != nil { return err } diff --git a/app/db/in_memory_repository_test.go b/app/db/in_memory_repository_test.go index bd9d12f..cf516aa 100644 --- a/app/db/in_memory_repository_test.go +++ b/app/db/in_memory_repository_test.go @@ -13,17 +13,17 @@ func TestInMemoryRepository(t *testing.T) { t.Run("Save", func(t *testing.T) { t.Run("an invalid Sparkle", func(t *testing.T) { - err := storage.Save(&domain.Sparkle{Reason: "because"}) + err := storage.Save(t.Context(), &domain.Sparkle{Reason: "because"}) assert.Error(t, err) - assert.Equal(t, 0, len(storage.All())) + assert.Equal(t, 0, len(storage.All(t.Context()))) }) t.Run("a valid Sparkle", func(t *testing.T) { sparkle := &domain.Sparkle{Sparklee: "@tanuki", Reason: "because"} - require.NoError(t, storage.Save(sparkle)) + require.NoError(t, storage.Save(t.Context(), sparkle)) - sparkles := storage.All() + sparkles := storage.All(t.Context()) assert.Equal(t, 1, len(sparkles)) assert.NotEmpty(t, sparkles[0].ID) assert.Equal(t, "@tanuki", sparkles[0].Sparklee) @@ -35,15 +35,15 @@ func TestInMemoryRepository(t *testing.T) { t.Run("when the entity exists", func(t *testing.T) { sparkle, err := domain.NewSparkle("@tanuki for testing this func") require.NoError(t, err) - require.NoError(t, storage.Save(sparkle)) + require.NoError(t, storage.Save(t.Context(), sparkle)) - result := storage.Find(sparkle.ID) + result := storage.Find(t.Context(), sparkle.ID) require.NotNil(t, result) require.Equal(t, sparkle, result) }) t.Run("when the entity does not exist", func(t *testing.T) { - result := storage.Find("unknown") + result := storage.Find(t.Context(), "unknown") require.Nil(t, result) }) }) -- cgit v1.2.3