summaryrefslogtreecommitdiff
path: root/app
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-13 14:10:16 -0600
committermo khan <mo@mokhan.ca>2025-06-13 14:10:16 -0600
commitd968a303e60ae6de25f0587b0b00008c332c5fb5 (patch)
treef1d9f124ad732eeaecefbc96b639d9260f008b53 /app
parent8dff2917704440c31bc4d28a4c1e763709b268ce (diff)
test: reproduce a race condition
Diffstat (limited to 'app')
-rw-r--r--app/db/in_memory_repository_test.go62
1 files changed, 62 insertions, 0 deletions
diff --git a/app/db/in_memory_repository_test.go b/app/db/in_memory_repository_test.go
index cf516aa..37fcbe0 100644
--- a/app/db/in_memory_repository_test.go
+++ b/app/db/in_memory_repository_test.go
@@ -1,6 +1,8 @@
package db
import (
+ "context"
+ "sync"
"testing"
"github.com/stretchr/testify/assert"
@@ -29,6 +31,66 @@ func TestInMemoryRepository(t *testing.T) {
assert.Equal(t, "@tanuki", sparkles[0].Sparklee)
assert.Equal(t, "because", sparkles[0].Reason)
})
+
+ t.Run("prevents race conditions", func(t *testing.T) {
+ repository := NewRepository[*domain.Sparkle]()
+ ctx := context.Background()
+
+ numGoroutines := 100
+ numOperationsPerGoroutine := 50
+
+ var wg sync.WaitGroup
+ errors := make(chan error, numGoroutines*2)
+
+ for i := 0; i < numGoroutines; i++ {
+ wg.Add(1)
+ go func(workerID int) {
+ defer wg.Done()
+ defer func() {
+ if r := recover(); r != nil {
+ errors <- assert.AnError
+ }
+ }()
+
+ for j := 0; j < numOperationsPerGoroutine; j++ {
+ if err := repository.Save(ctx, &domain.Sparkle{
+ Sparklee: "@user" + string(rune(workerID)),
+ Reason: "for running concurrently",
+ }); err != nil {
+ errors <- err
+ return
+ }
+ }
+ }(i)
+ }
+
+ for i := 0; i < numGoroutines; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ defer func() {
+ if r := recover(); r != nil {
+ errors <- assert.AnError
+ }
+ }()
+
+ for j := 0; j < numOperationsPerGoroutine; j++ {
+ _ = repository.All(ctx)
+ }
+ }()
+ }
+
+ wg.Wait()
+ close(errors)
+
+ var raceErrors []error
+ for err := range errors {
+ raceErrors = append(raceErrors, err)
+ }
+
+ assert.Equal(t, numGoroutines*numOperationsPerGoroutine, len(repository.All(ctx)))
+ assert.Empty(t, raceErrors)
+ })
})
t.Run("Find", func(t *testing.T) {