package db

import (
	"fmt"
	"strconv"
	"sync"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/xlgmokha/x/pkg/x"
	"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
)

// TestInMemoryRepository exercises the repository returned by NewRepository,
// instantiated with *domain.Sparkle as the entity type.
//
// NOTE: the top-level `storage` is shared by the "Save" and "Find" subtests.
// Subtests run sequentially in declaration order, and the item-count
// assertions in "Save" rely on that ordering.
func TestInMemoryRepository(t *testing.T) {
	storage := NewRepository[*domain.Sparkle]()

	t.Run("Save", func(t *testing.T) {
		t.Run("an invalid Sparkle", func(t *testing.T) {
			// A Sparkle with only a Reason (no Sparklee) must be rejected,
			// and nothing may be stored.
			assert.Error(t, storage.Save(t.Context(), &domain.Sparkle{Reason: "because"}))
			assert.Empty(t, storage.All(t.Context()))
		})

		t.Run("a valid Sparkle", func(t *testing.T) {
			sparkle := &domain.Sparkle{Sparklee: "@tanuki", Reason: "because"}
			require.NoError(t, storage.Save(t.Context(), sparkle))

			sparkles := storage.All(t.Context())
			// require (not assert) so the indexing below cannot panic.
			require.Len(t, sparkles, 1)
			assert.NotEmpty(t, sparkles[0].ID)
			assert.Equal(t, "@tanuki", sparkles[0].Sparklee)
			assert.Equal(t, "because", sparkles[0].Reason)
		})

		t.Run("prevents race conditions", func(t *testing.T) {
			// Concurrent writers and readers hammer a fresh repository.
			// Data races are only reliably detected under `go test -race`:
			// an unsynchronized map write usually aborts the process with a
			// runtime fatal error that recover cannot intercept.
			repository := NewRepository[*domain.Sparkle]()
			ctx := t.Context()

			const (
				numGoroutines             = 100
				numOperationsPerGoroutine = 50
			)

			var wg sync.WaitGroup
			// Each goroutine reports at most one error (either a Save error or
			// a recovered panic), so this buffer can never fill up and block.
			errCh := make(chan error, numGoroutines*2)

			// Writers: each saves numOperationsPerGoroutine sparkles.
			for i := 0; i < numGoroutines; i++ {
				wg.Add(1)
				go func(workerID int) {
					defer wg.Done()
					defer func() {
						if r := recover(); r != nil {
							errCh <- fmt.Errorf("writer %d panicked: %v", workerID, r)
						}
					}()

					for j := 0; j < numOperationsPerGoroutine; j++ {
						err := repository.Save(ctx, &domain.Sparkle{
							// strconv.Itoa yields "@user0".."@user99"; string(rune(id))
							// would produce control characters instead of digits.
							Sparklee: "@user" + strconv.Itoa(workerID),
							Reason:   "for running concurrently",
						})
						if err != nil {
							errCh <- fmt.Errorf("writer %d: %w", workerID, err)
							return
						}
					}
				}(i)
			}

			// Readers: repeatedly list all items while writers are active.
			for i := 0; i < numGoroutines; i++ {
				wg.Add(1)
				go func(readerID int) {
					defer wg.Done()
					defer func() {
						if r := recover(); r != nil {
							errCh <- fmt.Errorf("reader %d panicked: %v", readerID, r)
						}
					}()

					for j := 0; j < numOperationsPerGoroutine; j++ {
						// The result is irrelevant; only the concurrent access matters.
						_ = repository.All(ctx)
					}
				}(i)
			}

			wg.Wait()
			close(errCh)

			var raceErrors []error
			for err := range errCh {
				raceErrors = append(raceErrors, err)
			}

			assert.Empty(t, raceErrors)
			assert.Len(t, repository.All(ctx), numGoroutines*numOperationsPerGoroutine)
		})
	})

	t.Run("All", func(t *testing.T) {
		repository := NewRepository[*domain.Sparkle]()
		require.NoError(t, repository.Save(t.Context(), &domain.Sparkle{
			Sparklee: "@tanuki",
			Reason:   "because",
		}))

		t.Run("returns all the items", func(t *testing.T) {
			items := repository.All(t.Context())
			require.NotNil(t, items)
			require.Len(t, items, 1)
			assert.Equal(t, "@tanuki", items[0].Sparklee)
			assert.Equal(t, "because", items[0].Reason)
		})
	})

	t.Run("Find", func(t *testing.T) {
		t.Run("when the entity exists", func(t *testing.T) {
			sparkle := x.New[*domain.Sparkle](domain.WithText("@tanuki for testing this func"))
			require.NoError(t, storage.Save(t.Context(), sparkle))

			result := storage.Find(t.Context(), sparkle.ID)
			require.NotNil(t, result)
			require.Equal(t, sparkle, result)
		})

		t.Run("when the entity does not exist", func(t *testing.T) {
			result := storage.Find(t.Context(), "unknown")
			require.Nil(t, result)
		})
	})
}