diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-13 14:10:16 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-13 14:10:16 -0600 |
| commit | d968a303e60ae6de25f0587b0b00008c332c5fb5 (patch) | |
| tree | f1d9f124ad732eeaecefbc96b639d9260f008b53 /app | |
| parent | 8dff2917704440c31bc4d28a4c1e763709b268ce (diff) | |
test: reproduce a race condition
Diffstat (limited to 'app')
| -rw-r--r-- | app/db/in_memory_repository_test.go | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/app/db/in_memory_repository_test.go b/app/db/in_memory_repository_test.go index cf516aa..37fcbe0 100644 --- a/app/db/in_memory_repository_test.go +++ b/app/db/in_memory_repository_test.go @@ -1,6 +1,8 @@ package db import ( + "context" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -29,6 +31,66 @@ func TestInMemoryRepository(t *testing.T) { assert.Equal(t, "@tanuki", sparkles[0].Sparklee) assert.Equal(t, "because", sparkles[0].Reason) }) + + t.Run("prevents race conditions", func(t *testing.T) { + repository := NewRepository[*domain.Sparkle]() + ctx := context.Background() + + numGoroutines := 100 + numOperationsPerGoroutine := 50 + + var wg sync.WaitGroup + errors := make(chan error, numGoroutines*2) + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(workerID int) { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + errors <- assert.AnError + } + }() + + for j := 0; j < numOperationsPerGoroutine; j++ { + if err := repository.Save(ctx, &domain.Sparkle{ + Sparklee: "@user" + string(rune(workerID)), + Reason: "for running concurrently", + }); err != nil { + errors <- err + return + } + } + }(i) + } + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + defer func() { + if r := recover(); r != nil { + errors <- assert.AnError + } + }() + + for j := 0; j < numOperationsPerGoroutine; j++ { + _ = repository.All(ctx) + } + }() + } + + wg.Wait() + close(errors) + + var raceErrors []error + for err := range errors { + raceErrors = append(raceErrors, err) + } + + assert.Equal(t, numGoroutines*numOperationsPerGoroutine, len(repository.All(ctx))) + assert.Empty(t, raceErrors) + }) }) t.Run("Find", func(t *testing.T) { |
