summaryrefslogtreecommitdiff
path: root/app
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-04-25 11:08:58 -0600
committermo khan <mo@mokhan.ca>2025-04-25 11:08:58 -0600
commit2b1e14690ea6426a67c0faaaddcfb8aa7360dce7 (patch)
tree7f764225e3e3a26bbd7532e72ab99a54e465be92 /app
parent0053db0d265af313dd281db5cf1e73236cde30c6 (diff)
refactor: move db and mountable to app
Diffstat (limited to 'app')
-rw-r--r--app/app.go9
-rw-r--r--app/controllers/dashboard/controller.go2
-rw-r--r--app/controllers/sparkles/controller.go2
-rw-r--r--app/controllers/sparkles/controller_test.go2
-rw-r--r--app/db/in_memory_repository.go40
-rw-r--r--app/db/in_memory_repository_test.go50
-rw-r--r--app/init.go2
-rw-r--r--app/middleware/id_token.go56
-rw-r--r--app/middleware/id_token_test.go101
-rw-r--r--app/middleware/init.go24
-rw-r--r--app/middleware/require_user.go22
-rw-r--r--app/middleware/require_user_test.go43
-rw-r--r--app/middleware/user.go36
-rw-r--r--app/middleware/user_test.go76
14 files changed, 458 insertions, 7 deletions
diff --git a/app/app.go b/app/app.go
index 95cd908..80ab9ce 100644
--- a/app/app.go
+++ b/app/app.go
@@ -12,15 +12,18 @@ import (
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/controllers/sessions"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/controllers/sparkles"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/middleware"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc"
- "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/web"
- "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/web/middleware"
)
+type Mountable interface {
+ MountTo(*http.ServeMux)
+}
+
func New(rootDir string) http.Handler {
mux := ioc.MustResolve[*http.ServeMux](ioc.Default)
- mountable := []web.Mountable{
+ mountable := []Mountable{
ioc.MustResolve[*dashboard.Controller](ioc.Default),
ioc.MustResolve[*health.Controller](ioc.Default),
ioc.MustResolve[*sessions.Controller](ioc.Default),
diff --git a/app/controllers/dashboard/controller.go b/app/controllers/dashboard/controller.go
index 65b2fe5..a1d1bbf 100644
--- a/app/controllers/dashboard/controller.go
+++ b/app/controllers/dashboard/controller.go
@@ -4,9 +4,9 @@ import (
"net/http"
"github.com/xlgmokha/x/pkg/log"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/middleware"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/views"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key"
- "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/web/middleware"
)
type Controller struct {
diff --git a/app/controllers/sparkles/controller.go b/app/controllers/sparkles/controller.go
index 5cdb60d..9c319b2 100644
--- a/app/controllers/sparkles/controller.go
+++ b/app/controllers/sparkles/controller.go
@@ -8,7 +8,7 @@ import (
"github.com/xlgmokha/x/pkg/serde"
"github.com/xlgmokha/x/pkg/x"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
- "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/web/middleware"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/middleware"
)
type Controller struct {
diff --git a/app/controllers/sparkles/controller_test.go b/app/controllers/sparkles/controller_test.go
index 65a9622..21f4ec7 100644
--- a/app/controllers/sparkles/controller_test.go
+++ b/app/controllers/sparkles/controller_test.go
@@ -7,8 +7,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/xlgmokha/x/pkg/serde"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/db"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
- "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/db"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/test"
)
diff --git a/app/db/in_memory_repository.go b/app/db/in_memory_repository.go
new file mode 100644
index 0000000..5b84dbf
--- /dev/null
+++ b/app/db/in_memory_repository.go
@@ -0,0 +1,40 @@
+package db
+
+import (
+ "github.com/xlgmokha/x/pkg/x"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/pls"
+)
+
+type inMemoryRepository[T domain.Entity] struct {
+ items []T
+}
+
+func NewRepository[T domain.Entity]() domain.Repository[T] {
+ return &inMemoryRepository[T]{
+ items: []T{},
+ }
+}
+
+func (r *inMemoryRepository[T]) All() []T {
+ return r.items
+}
+
+func (r *inMemoryRepository[T]) Find(id domain.ID) T {
+ return x.Find(r.All(), func(item T) bool {
+ return item.GetID() == id
+ })
+}
+
+func (r *inMemoryRepository[T]) Save(item T) error {
+ if err := item.Validate(); err != nil {
+ return err
+ }
+
+ if item.GetID() == "" {
+ item.SetID(domain.ID(pls.GenerateULID()))
+ }
+
+ r.items = append(r.items, item)
+ return nil
+}
diff --git a/app/db/in_memory_repository_test.go b/app/db/in_memory_repository_test.go
new file mode 100644
index 0000000..bd9d12f
--- /dev/null
+++ b/app/db/in_memory_repository_test.go
@@ -0,0 +1,50 @@
+package db
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
+)
+
+func TestInMemoryRepository(t *testing.T) {
+ storage := NewRepository[*domain.Sparkle]()
+
+ t.Run("Save", func(t *testing.T) {
+ t.Run("an invalid Sparkle", func(t *testing.T) {
+ err := storage.Save(&domain.Sparkle{Reason: "because"})
+
+ assert.Error(t, err)
+ assert.Equal(t, 0, len(storage.All()))
+ })
+
+ t.Run("a valid Sparkle", func(t *testing.T) {
+ sparkle := &domain.Sparkle{Sparklee: "@tanuki", Reason: "because"}
+ require.NoError(t, storage.Save(sparkle))
+
+ sparkles := storage.All()
+ assert.Equal(t, 1, len(sparkles))
+ assert.NotEmpty(t, sparkles[0].ID)
+ assert.Equal(t, "@tanuki", sparkles[0].Sparklee)
+ assert.Equal(t, "because", sparkles[0].Reason)
+ })
+ })
+
+ t.Run("Find", func(t *testing.T) {
+ t.Run("when the entity exists", func(t *testing.T) {
+ sparkle, err := domain.NewSparkle("@tanuki for testing this func")
+ require.NoError(t, err)
+ require.NoError(t, storage.Save(sparkle))
+
+ result := storage.Find(sparkle.ID)
+ require.NotNil(t, result)
+ require.Equal(t, sparkle, result)
+ })
+
+ t.Run("when the entity does not exist", func(t *testing.T) {
+ result := storage.Find("unknown")
+ require.Nil(t, result)
+ })
+ })
+}
diff --git a/app/init.go b/app/init.go
index a42d2f7..968303b 100644
--- a/app/init.go
+++ b/app/init.go
@@ -13,8 +13,8 @@ import (
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/controllers/health"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/controllers/sessions"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/controllers/sparkles"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/db"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
- "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/db"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/web"
"golang.org/x/oauth2"
diff --git a/app/middleware/id_token.go b/app/middleware/id_token.go
new file mode 100644
index 0000000..a32c77b
--- /dev/null
+++ b/app/middleware/id_token.go
@@ -0,0 +1,56 @@
+package middleware
+
+import (
+ "net/http"
+
+ "github.com/xlgmokha/x/pkg/log"
+ "github.com/xlgmokha/x/pkg/x"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc"
+)
+
+type TokenParser func(*http.Request) oidc.RawToken
+
+func IDTokenFromSessionCookie(r *http.Request) oidc.RawToken {
+ cookies := r.CookiesNamed("session")
+
+ if len(cookies) != 1 {
+ return ""
+ }
+
+ tokens, err := oidc.TokensFromBase64String(cookies[0].Value)
+ if err != nil {
+ log.WithFields(r.Context(), log.Fields{"error": err})
+ return ""
+ }
+
+ return tokens.IDToken
+}
+
+func IDToken(cfg *oidc.OpenID) func(http.Handler) http.Handler {
+ parsers := []TokenParser{IDTokenFromSessionCookie}
+
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ for _, parser := range parsers {
+ rawIDToken := parser(r)
+ if !x.IsZero(rawIDToken) {
+ verifier := cfg.Provider.VerifierContext(r.Context(), cfg.OIDCConfig)
+ idToken, err := verifier.Verify(r.Context(), rawIDToken.String())
+ if err != nil {
+ log.WithFields(r.Context(), log.Fields{"error": err})
+ } else {
+ log.WithFields(r.Context(), log.Fields{"id_token": idToken})
+ next.ServeHTTP(
+ w,
+ r.WithContext(key.IDToken.With(r.Context(), idToken)),
+ )
+ return
+ }
+ }
+ }
+
+ next.ServeHTTP(w, r)
+ })
+ }
+}
diff --git a/app/middleware/id_token_test.go b/app/middleware/id_token_test.go
new file mode 100644
index 0000000..4f26cdf
--- /dev/null
+++ b/app/middleware/id_token_test.go
@@ -0,0 +1,101 @@
+package middleware
+
+import (
+ "context"
+ "net/http"
+ "os"
+ "testing"
+ "time"
+
+ "github.com/oauth2-proxy/mockoidc"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/xlgmokha/x/pkg/log"
+ "github.com/xlgmokha/x/pkg/x"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/test"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/web"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/web/cookie"
+ "golang.org/x/oauth2"
+)
+
+func TestIDToken(t *testing.T) {
+ srv := test.NewOIDCServer(t)
+ defer srv.Close()
+
+ client := &http.Client{Transport: &web.Transport{Logger: log.New(os.Stdout, log.Fields{})}}
+ cfg := srv.MockOIDC.Config()
+ ctx := context.WithValue(t.Context(), oauth2.HTTPClient, client)
+ openID, err := oidc.New(
+ ctx,
+ srv.Issuer(),
+ cfg.ClientID,
+ cfg.ClientSecret,
+ "https://example.com/oauth/callback",
+ )
+ require.NoError(t, err)
+
+ middleware := IDToken(openID)
+
+ t.Run("when an active session cookie is provided", func(t *testing.T) {
+ t.Run("attaches the token to the request context", func(t *testing.T) {
+ user := mockoidc.DefaultUser()
+
+ token, rawIDToken := srv.CreateTokensFor(user)
+ tokens := &oidc.Tokens{Token: token, IDToken: oidc.RawToken(rawIDToken)}
+ encoded := x.Must(tokens.ToBase64String())
+
+ server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ token := key.IDToken.From(r.Context())
+ require.NotNil(t, token)
+ assert.Equal(t, user.Subject, token.Subject)
+
+ w.WriteHeader(http.StatusTeapot)
+ }))
+
+ r, w := test.RequestResponse(
+ "GET",
+ "/example",
+ test.WithCookie(cookie.New("session", encoded, time.Now().Add(1*time.Hour))),
+ )
+ server.ServeHTTP(w, r)
+
+ assert.Equal(t, http.StatusTeapot, w.Code)
+ })
+ })
+
+ t.Run("when an invalid session cookie is provided", func(t *testing.T) {
+ t.Run("forwards the request", func(t *testing.T) {
+ server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Nil(t, key.IDToken.From(r.Context()))
+
+ w.WriteHeader(http.StatusTeapot)
+ }))
+
+ r, w := test.RequestResponse(
+ "GET",
+ "/example",
+ test.WithCookie(cookie.New("session", "invalid", time.Now().Add(1*time.Hour))),
+ )
+ server.ServeHTTP(w, r)
+
+ assert.Equal(t, http.StatusTeapot, w.Code)
+ })
+ })
+
+ t.Run("when no cookies are provided", func(t *testing.T) {
+ t.Run("forwards the request", func(t *testing.T) {
+ server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Nil(t, key.IDToken.From(r.Context()))
+
+ w.WriteHeader(http.StatusTeapot)
+ }))
+
+ r, w := test.RequestResponse("GET", "/example")
+ server.ServeHTTP(w, r)
+
+ assert.Equal(t, http.StatusTeapot, w.Code)
+ })
+ })
+}
diff --git a/app/middleware/init.go b/app/middleware/init.go
new file mode 100644
index 0000000..f1a693d
--- /dev/null
+++ b/app/middleware/init.go
@@ -0,0 +1,24 @@
+package middleware
+
+import (
+ "github.com/xlgmokha/x/pkg/mapper"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc"
+)
+
+func init() {
+ mapper.Register(func(idToken *oidc.IDToken) *domain.User {
+ customClaims := &oidc.CustomClaims{}
+ if err := idToken.Claims(customClaims); err != nil {
+ return &domain.User{ID: domain.ID(idToken.Subject)}
+ }
+
+ return &domain.User{
+ ID: domain.ID(idToken.Subject),
+ Username: customClaims.Nickname,
+ Email: customClaims.Email,
+ ProfileURL: customClaims.ProfileURL,
+ Picture: customClaims.Picture,
+ }
+ })
+}
diff --git a/app/middleware/require_user.go b/app/middleware/require_user.go
new file mode 100644
index 0000000..e81d5b5
--- /dev/null
+++ b/app/middleware/require_user.go
@@ -0,0 +1,22 @@
+package middleware
+
+import (
+ "net/http"
+
+ "github.com/xlgmokha/x/pkg/x"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key"
+)
+
+func RequireUser(code int, url string) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ user := key.CurrentUser.From(r.Context())
+ if x.IsZero(user) {
+ http.Redirect(w, r, url, code)
+ return
+ }
+
+ next.ServeHTTP(w, r)
+ })
+ }
+}
diff --git a/app/middleware/require_user_test.go b/app/middleware/require_user_test.go
new file mode 100644
index 0000000..68b9911
--- /dev/null
+++ b/app/middleware/require_user_test.go
@@ -0,0 +1,43 @@
+package middleware
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/test"
+)
+
+func TestRequireUser(t *testing.T) {
+ middleware := RequireUser(http.StatusFound, "/login")
+
+ t.Run("when a user is not logged in", func(t *testing.T) {
+ t.Run("redirects to the homepage", func(t *testing.T) {
+ r, w := test.RequestResponse("GET", "/example")
+
+ server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ require.Fail(t, "unexpected call to handler")
+ }))
+ server.ServeHTTP(w, r)
+
+ require.Equal(t, http.StatusFound, w.Code)
+ assert.Equal(t, "/login", w.Header().Get("Location"))
+ })
+ })
+
+ t.Run("when a user is logged in", func(t *testing.T) {
+ t.Run("forwards the request", func(t *testing.T) {
+ r, w := test.RequestResponse("GET", "/example", test.WithContextKeyValue(t.Context(), key.CurrentUser, &domain.User{}))
+
+ server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusTeapot)
+ }))
+ server.ServeHTTP(w, r)
+
+ require.Equal(t, http.StatusTeapot, w.Code)
+ })
+ })
+}
diff --git a/app/middleware/user.go b/app/middleware/user.go
new file mode 100644
index 0000000..194ded6
--- /dev/null
+++ b/app/middleware/user.go
@@ -0,0 +1,36 @@
+package middleware
+
+import (
+ "net/http"
+
+ "github.com/xlgmokha/x/pkg/log"
+ "github.com/xlgmokha/x/pkg/mapper"
+ "github.com/xlgmokha/x/pkg/x"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc"
+)
+
+func User(db domain.Repository[*domain.User]) func(http.Handler) http.Handler {
+ return func(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ idToken := key.IDToken.From(r.Context())
+ if x.IsZero(idToken) {
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ user := db.Find(domain.ID(idToken.Subject))
+ if x.IsZero(user) {
+ user = mapper.MapFrom[*oidc.IDToken, *domain.User](idToken)
+ if err := db.Save(user); err != nil {
+ log.WithFields(r.Context(), log.Fields{"error": err})
+ next.ServeHTTP(w, r)
+ return
+ }
+ }
+
+ next.ServeHTTP(w, r.WithContext(key.CurrentUser.With(r.Context(), user)))
+ })
+ }
+}
diff --git a/app/middleware/user_test.go b/app/middleware/user_test.go
new file mode 100644
index 0000000..e6c74d8
--- /dev/null
+++ b/app/middleware/user_test.go
@@ -0,0 +1,76 @@
+package middleware
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/db"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/app/domain"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/key"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/oidc"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/pls"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/test"
+)
+
+func TestUser(t *testing.T) {
+ repository := db.NewRepository[*domain.User]()
+ middleware := User(repository)
+
+ knownUser := &domain.User{ID: domain.ID(pls.GenerateULID())}
+ require.NoError(t, repository.Save(knownUser))
+
+ t.Run("when ID Token is provided", func(t *testing.T) {
+ t.Run("when user is known", func(t *testing.T) {
+ server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ user := key.CurrentUser.From(r.Context())
+ require.NotNil(t, user)
+ assert.Equal(t, knownUser.ID, user.ID)
+
+ w.WriteHeader(http.StatusTeapot)
+ }))
+
+ ctx := key.IDToken.With(t.Context(), &oidc.IDToken{Subject: knownUser.ID.String()})
+
+ r, w := test.RequestResponse("GET", "/example", test.WithContext(ctx))
+ server.ServeHTTP(w, r)
+
+ assert.Equal(t, http.StatusTeapot, w.Code)
+ })
+
+ t.Run("when user is unknown", func(t *testing.T) {
+ unknownID := pls.GenerateULID()
+
+ server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ user := key.CurrentUser.From(r.Context())
+ require.NotNil(t, user)
+ assert.Equal(t, domain.ID(unknownID), user.ID)
+
+ w.WriteHeader(http.StatusTeapot)
+ }))
+
+ ctx := key.IDToken.With(t.Context(), &oidc.IDToken{Subject: unknownID})
+
+ r, w := test.RequestResponse("GET", "/example", test.WithContext(ctx))
+ server.ServeHTTP(w, r)
+
+ assert.Equal(t, http.StatusTeapot, w.Code)
+ require.NotNil(t, repository.Find(domain.ID(unknownID)))
+ })
+ })
+
+ t.Run("when ID Token is not provided", func(t *testing.T) {
+ server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ user := key.CurrentUser.From(r.Context())
+ require.Nil(t, user)
+
+ w.WriteHeader(http.StatusTeapot)
+ }))
+
+ r, w := test.RequestResponse("GET", "/example")
+ server.ServeHTTP(w, r)
+
+ assert.Equal(t, http.StatusTeapot, w.Code)
+ })
+}