summaryrefslogtreecommitdiff
path: root/pkg/web
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-04-25 10:10:28 -0600
committermo khan <mo@mokhan.ca>2025-04-25 10:10:28 -0600
commit00b0381dfccab2ddff7de04933fdb11b32695faf (patch)
treedb1e1f0048ba9cccdb415a714cb1344e8f47951f /pkg/web
parent051e2435b255e0995e97d927b73f643149d9b2f3 (diff)
refactor: move id and entity to domain package
Diffstat (limited to 'pkg/web')
-rw-r--r--pkg/web/middleware/init.go11
-rw-r--r--pkg/web/middleware/user.go2
-rw-r--r--pkg/web/middleware/user_test.go8
3 files changed, 15 insertions, 6 deletions
diff --git a/pkg/web/middleware/init.go b/pkg/web/middleware/init.go
index ccf4836..ac06c32 100644
--- a/pkg/web/middleware/init.go
+++ b/pkg/web/middleware/init.go
@@ -8,8 +8,17 @@ import (
func init() {
mapper.Register(func(idToken *oidc.IDToken) *domain.User {
+ customClaims := &oidc.CustomClaims{}
+ if err := idToken.Claims(customClaims); err != nil {
+ return &domain.User{ID: domain.ID(idToken.Subject)}
+ }
+
return &domain.User{
- ID: idToken.Subject,
+ ID: domain.ID(idToken.Subject),
+ Username: customClaims.Nickname,
+ Email: customClaims.Email,
+ ProfileURL: customClaims.ProfileURL,
+ Picture: customClaims.Picture,
}
})
}
diff --git a/pkg/web/middleware/user.go b/pkg/web/middleware/user.go
index 1e95ce0..68d2daa 100644
--- a/pkg/web/middleware/user.go
+++ b/pkg/web/middleware/user.go
@@ -21,7 +21,7 @@ func User(db db.Repository[*domain.User]) func(http.Handler) http.Handler {
return
}
- user := db.Find(idToken.Subject)
+ user := db.Find(domain.ID(idToken.Subject))
if x.IsZero(user) {
user = mapper.MapFrom[*oidc.IDToken, *domain.User](idToken)
if err := db.Save(user); err != nil {
diff --git a/pkg/web/middleware/user_test.go b/pkg/web/middleware/user_test.go
index 447d54d..c18bfdb 100644
--- a/pkg/web/middleware/user_test.go
+++ b/pkg/web/middleware/user_test.go
@@ -18,7 +18,7 @@ func TestUser(t *testing.T) {
repository := db.NewRepository[*domain.User]()
middleware := User(repository)
- knownUser := &domain.User{ID: pls.GenerateULID()}
+ knownUser := &domain.User{ID: domain.ID(pls.GenerateULID())}
require.NoError(t, repository.Save(knownUser))
t.Run("when ID Token is provided", func(t *testing.T) {
@@ -31,7 +31,7 @@ func TestUser(t *testing.T) {
w.WriteHeader(http.StatusTeapot)
}))
- ctx := key.IDToken.With(t.Context(), &oidc.IDToken{Subject: knownUser.ID})
+ ctx := key.IDToken.With(t.Context(), &oidc.IDToken{Subject: knownUser.ID.String()})
r, w := test.RequestResponse("GET", "/example", test.WithContext(ctx))
server.ServeHTTP(w, r)
@@ -45,7 +45,7 @@ func TestUser(t *testing.T) {
server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := key.CurrentUser.From(r.Context())
require.NotNil(t, user)
- assert.Equal(t, unknownID, user.ID)
+ assert.Equal(t, domain.ID(unknownID), user.ID)
w.WriteHeader(http.StatusTeapot)
}))
@@ -56,7 +56,7 @@ func TestUser(t *testing.T) {
server.ServeHTTP(w, r)
assert.Equal(t, http.StatusTeapot, w.Code)
- require.NotNil(t, repository.Find(unknownID))
+ require.NotNil(t, repository.Find(domain.ID(unknownID)))
})
})