diff options
Diffstat (limited to 'pkg/web')
| -rw-r--r-- | pkg/web/middleware/init.go | 11 | ||||
| -rw-r--r-- | pkg/web/middleware/user.go | 2 | ||||
| -rw-r--r-- | pkg/web/middleware/user_test.go | 8 |
3 files changed, 15 insertions, 6 deletions
diff --git a/pkg/web/middleware/init.go b/pkg/web/middleware/init.go index ccf4836..ac06c32 100644 --- a/pkg/web/middleware/init.go +++ b/pkg/web/middleware/init.go @@ -8,8 +8,17 @@ import ( func init() { mapper.Register(func(idToken *oidc.IDToken) *domain.User { + customClaims := &oidc.CustomClaims{} + if err := idToken.Claims(customClaims); err != nil { + return &domain.User{ID: domain.ID(idToken.Subject)} + } + return &domain.User{ - ID: idToken.Subject, + ID: domain.ID(idToken.Subject), + Username: customClaims.Nickname, + Email: customClaims.Email, + ProfileURL: customClaims.ProfileURL, + Picture: customClaims.Picture, } }) } diff --git a/pkg/web/middleware/user.go b/pkg/web/middleware/user.go index 1e95ce0..68d2daa 100644 --- a/pkg/web/middleware/user.go +++ b/pkg/web/middleware/user.go @@ -21,7 +21,7 @@ func User(db db.Repository[*domain.User]) func(http.Handler) http.Handler { return } - user := db.Find(idToken.Subject) + user := db.Find(domain.ID(idToken.Subject)) if x.IsZero(user) { user = mapper.MapFrom[*oidc.IDToken, *domain.User](idToken) if err := db.Save(user); err != nil { diff --git a/pkg/web/middleware/user_test.go b/pkg/web/middleware/user_test.go index 447d54d..c18bfdb 100644 --- a/pkg/web/middleware/user_test.go +++ b/pkg/web/middleware/user_test.go @@ -18,7 +18,7 @@ func TestUser(t *testing.T) { repository := db.NewRepository[*domain.User]() middleware := User(repository) - knownUser := &domain.User{ID: pls.GenerateULID()} + knownUser := &domain.User{ID: domain.ID(pls.GenerateULID())} require.NoError(t, repository.Save(knownUser)) t.Run("when ID Token is provided", func(t *testing.T) { @@ -31,7 +31,7 @@ func TestUser(t *testing.T) { w.WriteHeader(http.StatusTeapot) })) - ctx := key.IDToken.With(t.Context(), &oidc.IDToken{Subject: knownUser.ID}) + ctx := key.IDToken.With(t.Context(), &oidc.IDToken{Subject: knownUser.ID.String()}) r, w := test.RequestResponse("GET", "/example", test.WithContext(ctx)) server.ServeHTTP(w, r) @@ -45,7 +45,7 @@ func TestUser(t *testing.T) { server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { user := key.CurrentUser.From(r.Context()) require.NotNil(t, user) - assert.Equal(t, unknownID, user.ID) + assert.Equal(t, domain.ID(unknownID), user.ID) w.WriteHeader(http.StatusTeapot) })) @@ -56,7 +56,7 @@ func TestUser(t *testing.T) { server.ServeHTTP(w, r) assert.Equal(t, http.StatusTeapot, w.Code) - require.NotNil(t, repository.Find(unknownID)) + require.NotNil(t, repository.Find(domain.ID(unknownID))) }) }) |
