summaryrefslogtreecommitdiff
path: root/app
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-05-07 10:47:31 -0700
committermo khan <mo@mokhan.ca>2025-05-07 10:47:31 -0700
commitc82468b1b32ad5bfb347fe65cd5dcfb6680795d1 (patch)
tree679b4728fdae2ed296730a49f5100ddcf3c25f98 /app
parentf0fbdab72254d68d0a3a4a49a4a1646f89f0f913 (diff)
refactor: provide context to repository to apply timeout
Diffstat (limited to 'app')
-rw-r--r--app/controllers/sparkles/controller.go6
-rw-r--r--app/controllers/sparkles/controller_test.go6
-rw-r--r--app/db/in_memory_repository.go9
-rw-r--r--app/db/in_memory_repository_test.go14
-rw-r--r--app/domain/repository.go8
-rw-r--r--app/middleware/user.go4
-rw-r--r--app/middleware/user_test.go4
7 files changed, 27 insertions, 24 deletions
diff --git a/app/controllers/sparkles/controller.go b/app/controllers/sparkles/controller.go
index 4963950..cd86cd2 100644
--- a/app/controllers/sparkles/controller.go
+++ b/app/controllers/sparkles/controller.go
@@ -31,7 +31,7 @@ func (c *Controller) MountTo(mux *http.ServeMux) {
}
func (c *Controller) Index(w http.ResponseWriter, r *http.Request) {
- if err := serde.ToHTTP(w, r, c.db.All()); err != nil {
+ if err := serde.ToHTTP(w, r, c.db.All(r.Context())); err != nil {
pls.LogError(r.Context(), err)
w.WriteHeader(http.StatusInternalServerError)
}
@@ -45,7 +45,7 @@ func (c *Controller) Create(w http.ResponseWriter, r *http.Request) {
return
}
- if err := c.db.Save(sparkle); err != nil {
+ if err := c.db.Save(r.Context(), sparkle); err != nil {
pls.LogError(r.Context(), err)
w.WriteHeader(http.StatusBadRequest)
return
@@ -66,7 +66,7 @@ func (c *Controller) Restore(w http.ResponseWriter, r *http.Request) {
log.WithFields(r.Context(), log.Fields{"sparkles": sparkles})
x.Each(sparkles, func(sparkle *domain.Sparkle) {
- if err := c.db.Save(sparkle); err != nil {
+ if err := c.db.Save(r.Context(), sparkle); err != nil {
pls.LogError(r.Context(), err)
}
})
diff --git a/app/controllers/sparkles/controller_test.go b/app/controllers/sparkles/controller_test.go
index b2c7752..8a1717d 100644
--- a/app/controllers/sparkles/controller_test.go
+++ b/app/controllers/sparkles/controller_test.go
@@ -17,7 +17,7 @@ func TestSparkles(t *testing.T) {
t.Run("GET /sparkles", func(t *testing.T) {
sparkle, _ := domain.NewSparkle("@tanuki for helping me")
store := db.NewRepository[*domain.Sparkle]()
- store.Save(sparkle)
+ store.Save(t.Context(), sparkle)
mux := http.NewServeMux()
controller := New(store)
@@ -75,8 +75,8 @@ func TestSparkles(t *testing.T) {
})
t.Run("saves the sparkle to the db", func(t *testing.T) {
- assert.Equal(t, 1, len(repository.All()))
- item := repository.All()[0]
+ assert.Equal(t, 1, len(repository.All(t.Context())))
+ item := repository.All(t.Context())[0]
assert.Equal(t, "@tanuki", item.Sparklee)
assert.Equal(t, "for reviewing my code!", item.Reason)
diff --git a/app/db/in_memory_repository.go b/app/db/in_memory_repository.go
index 5d8628d..ba9ebad 100644
--- a/app/db/in_memory_repository.go
+++ b/app/db/in_memory_repository.go
@@ -1,6 +1,7 @@
package db
import (
+ "context"
"sort"
"github.com/xlgmokha/x/pkg/x"
@@ -18,17 +19,17 @@ func NewRepository[T domain.Entity]() domain.Repository[T] {
}
}
-func (r *inMemoryRepository[T]) All() []T {
+func (r *inMemoryRepository[T]) All(ctx context.Context) []T {
return r.items
}
-func (r *inMemoryRepository[T]) Find(id domain.ID) T {
- return x.Find(r.All(), func(item T) bool {
+func (r *inMemoryRepository[T]) Find(ctx context.Context, id domain.ID) T {
+ return x.Find(r.All(ctx), func(item T) bool {
return item.GetID() == id
})
}
-func (r *inMemoryRepository[T]) Save(item T) error {
+func (r *inMemoryRepository[T]) Save(ctx context.Context, item T) error {
if err := item.Validate(); err != nil {
return err
}
diff --git a/app/db/in_memory_repository_test.go b/app/db/in_memory_repository_test.go
index bd9d12f..cf516aa 100644
--- a/app/db/in_memory_repository_test.go
+++ b/app/db/in_memory_repository_test.go
@@ -13,17 +13,17 @@ func TestInMemoryRepository(t *testing.T) {
t.Run("Save", func(t *testing.T) {
t.Run("an invalid Sparkle", func(t *testing.T) {
- err := storage.Save(&domain.Sparkle{Reason: "because"})
+ err := storage.Save(t.Context(), &domain.Sparkle{Reason: "because"})
assert.Error(t, err)
- assert.Equal(t, 0, len(storage.All()))
+ assert.Equal(t, 0, len(storage.All(t.Context())))
})
t.Run("a valid Sparkle", func(t *testing.T) {
sparkle := &domain.Sparkle{Sparklee: "@tanuki", Reason: "because"}
- require.NoError(t, storage.Save(sparkle))
+ require.NoError(t, storage.Save(t.Context(), sparkle))
- sparkles := storage.All()
+ sparkles := storage.All(t.Context())
assert.Equal(t, 1, len(sparkles))
assert.NotEmpty(t, sparkles[0].ID)
assert.Equal(t, "@tanuki", sparkles[0].Sparklee)
@@ -35,15 +35,15 @@ func TestInMemoryRepository(t *testing.T) {
t.Run("when the entity exists", func(t *testing.T) {
sparkle, err := domain.NewSparkle("@tanuki for testing this func")
require.NoError(t, err)
- require.NoError(t, storage.Save(sparkle))
+ require.NoError(t, storage.Save(t.Context(), sparkle))
- result := storage.Find(sparkle.ID)
+ result := storage.Find(t.Context(), sparkle.ID)
require.NotNil(t, result)
require.Equal(t, sparkle, result)
})
t.Run("when the entity does not exist", func(t *testing.T) {
- result := storage.Find("unknown")
+ result := storage.Find(t.Context(), "unknown")
require.Nil(t, result)
})
})
diff --git a/app/domain/repository.go b/app/domain/repository.go
index fb7b6da..ae9da9e 100644
--- a/app/domain/repository.go
+++ b/app/domain/repository.go
@@ -1,7 +1,9 @@
package domain
+import "context"
+
type Repository[T Entity] interface {
- All() []T
- Find(ID) T
- Save(T) error
+ All(context.Context) []T
+ Find(context.Context, ID) T
+ Save(context.Context, T) error
}
diff --git a/app/middleware/user.go b/app/middleware/user.go
index 03c04d6..c0181f9 100644
--- a/app/middleware/user.go
+++ b/app/middleware/user.go
@@ -20,10 +20,10 @@ func User(db domain.Repository[*domain.User]) func(http.Handler) http.Handler {
return
}
- user := db.Find(domain.ID(idToken.Subject))
+ user := db.Find(r.Context(), domain.ID(idToken.Subject))
if !x.IsPresent(user) {
user = mapper.MapFrom[*oidc.IDToken, *domain.User](idToken)
- if err := db.Save(user); err != nil {
+ if err := db.Save(r.Context(), user); err != nil {
pls.LogError(r.Context(), err)
next.ServeHTTP(w, r)
return
diff --git a/app/middleware/user_test.go b/app/middleware/user_test.go
index e1bbcf9..e6ba09d 100644
--- a/app/middleware/user_test.go
+++ b/app/middleware/user_test.go
@@ -19,7 +19,7 @@ func TestUser(t *testing.T) {
middleware := User(repository)
knownUser := &domain.User{ID: domain.ID(pls.GenerateULID())}
- require.NoError(t, repository.Save(knownUser))
+ require.NoError(t, repository.Save(t.Context(), knownUser))
t.Run("when ID Token is provided", func(t *testing.T) {
t.Run("when user is known", func(t *testing.T) {
@@ -56,7 +56,7 @@ func TestUser(t *testing.T) {
server.ServeHTTP(w, r)
assert.Equal(t, http.StatusTeapot, w.Code)
- require.NotNil(t, repository.Find(domain.ID(unknownID)))
+ require.NotNil(t, repository.Find(t.Context(), domain.ID(unknownID)))
})
})