summaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-04-25 10:10:28 -0600
committermo khan <mo@mokhan.ca>2025-04-25 10:10:28 -0600
commit00b0381dfccab2ddff7de04933fdb11b32695faf (patch)
treedb1e1f0048ba9cccdb415a714cb1344e8f47951f /pkg
parent051e2435b255e0995e97d927b73f643149d9b2f3 (diff)
refactor: move id and entity to domain package
Diffstat (limited to 'pkg')
-rw-r--r--pkg/db/entity.go7
-rw-r--r--pkg/db/in_memory_repository.go9
-rw-r--r--pkg/db/repository.go6
-rw-r--r--pkg/domain/entity.go7
-rw-r--r--pkg/domain/id.go7
-rw-r--r--pkg/domain/sparkle.go14
-rw-r--r--pkg/domain/user.go24
-rw-r--r--pkg/domain/user_test.go24
-rw-r--r--pkg/oidc/id_token.go9
-rw-r--r--pkg/web/middleware/init.go11
-rw-r--r--pkg/web/middleware/user.go2
-rw-r--r--pkg/web/middleware/user_test.go8
12 files changed, 97 insertions, 31 deletions
diff --git a/pkg/db/entity.go b/pkg/db/entity.go
deleted file mode 100644
index 1dcf4c3..0000000
--- a/pkg/db/entity.go
+++ /dev/null
@@ -1,7 +0,0 @@
-package db
-
-type Entity interface {
- GetID() string
- SetID(id string) error
- Validate() error
-}
diff --git a/pkg/db/in_memory_repository.go b/pkg/db/in_memory_repository.go
index 5859c0b..56ca766 100644
--- a/pkg/db/in_memory_repository.go
+++ b/pkg/db/in_memory_repository.go
@@ -2,14 +2,15 @@ package db
import (
"github.com/xlgmokha/x/pkg/x"
+ "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/domain"
"gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/pls"
)
-type inMemoryRepository[T Entity] struct {
+type inMemoryRepository[T domain.Entity] struct {
items []T
}
-func NewRepository[T Entity]() Repository[T] {
+func NewRepository[T domain.Entity]() Repository[T] {
return &inMemoryRepository[T]{
items: []T{},
}
@@ -19,7 +20,7 @@ func (r *inMemoryRepository[T]) All() []T {
return r.items
}
-func (r *inMemoryRepository[T]) Find(id string) T {
+func (r *inMemoryRepository[T]) Find(id domain.ID) T {
return x.Find(r.All(), func(item T) bool {
return item.GetID() == id
})
@@ -31,7 +32,7 @@ func (r *inMemoryRepository[T]) Save(item T) error {
}
if item.GetID() == "" {
- item.SetID(pls.GenerateULID())
+ item.SetID(domain.ID(pls.GenerateULID()))
}
r.items = append(r.items, item)
diff --git a/pkg/db/repository.go b/pkg/db/repository.go
index 397eee7..0ebe216 100644
--- a/pkg/db/repository.go
+++ b/pkg/db/repository.go
@@ -1,7 +1,9 @@
package db
-type Repository[T Entity] interface {
+import "gitlab.com/gitlab-org/software-supply-chain-security/authorization/sparkled/pkg/domain"
+
+type Repository[T domain.Entity] interface {
All() []T
- Find(string) T
+ Find(domain.ID) T
Save(T) error
}
diff --git a/pkg/domain/entity.go b/pkg/domain/entity.go
new file mode 100644
index 0000000..fb1cab8
--- /dev/null
+++ b/pkg/domain/entity.go
@@ -0,0 +1,7 @@
+package domain
+
+type Entity interface {
+ GetID() ID
+ SetID(id ID) error
+ Validate() error
+}
diff --git a/pkg/domain/id.go b/pkg/domain/id.go
new file mode 100644
index 0000000..117a9ad
--- /dev/null
+++ b/pkg/domain/id.go
@@ -0,0 +1,7 @@
+package domain
+
+type ID string
+
+func (id ID) String() string {
+ return string(id)
+}
diff --git a/pkg/domain/sparkle.go b/pkg/domain/sparkle.go
index 3f8b4f4..5c07fbc 100644
--- a/pkg/domain/sparkle.go
+++ b/pkg/domain/sparkle.go
@@ -8,9 +8,11 @@ import (
)
type Sparkle struct {
- ID string `json:"id" jsonapi:"primary,sparkles"`
- Sparklee string `json:"sparklee" jsonapi:"attr,sparklee"`
- Reason string `json:"reason" jsonapi:"attr,reason"`
+ ID ID `json:"id" jsonapi:"primary,sparkles"`
+ Sparklee string `json:"sparklee" jsonapi:"attr,sparklee"`
+ Recipient *User `json:"recipient" jsonapi:"attr,recipient"`
+ Author *User `json:"author" jsonapi:"attr,author"`
+ Reason string `json:"reason" jsonapi:"attr,reason"`
}
var SparkleRegex = regexp.MustCompile(`\A\s*(?P<sparklee>@\w+)\s+(?P<reason>.+)\z`)
@@ -33,17 +35,17 @@ func NewSparkle(text string) (*Sparkle, error) {
}
return &Sparkle{
- ID: pls.GenerateULID(),
+ ID: ID(pls.GenerateULID()),
Sparklee: matches[SparkleeIndex],
Reason: matches[ReasonIndex],
}, nil
}
-func (s *Sparkle) GetID() string {
+func (s *Sparkle) GetID() ID {
return s.ID
}
-func (s *Sparkle) SetID(id string) error {
+func (s *Sparkle) SetID(id ID) error {
s.ID = id
return nil
}
diff --git a/pkg/domain/user.go b/pkg/domain/user.go
index 4053e8f..5a0420b 100644
--- a/pkg/domain/user.go
+++ b/pkg/domain/user.go
@@ -1,22 +1,34 @@
package domain
type User struct {
- ID string `json:"id" jsonapi:"primary,users"`
+ ID ID `json:"id" jsonapi:"primary,users"`
+ Username string `json:"username" jsonapi:"attr,username"`
+ Email string `json:"email" jsonapi:"attr,email"`
+ ProfileURL string `json:"profile" jsonapi:"attr,profile"`
+ Picture string `json:"picture" jsonapi:"attr,picture"`
}
func NewUser() *User {
return &User{}
}
-func (s *User) GetID() string {
- return s.ID
+func (u *User) GetID() ID {
+ return u.ID
}
-func (s *User) SetID(id string) error {
- s.ID = id
+func (u *User) SetID(id ID) error {
+ u.ID = id
return nil
}
-func (s *User) Validate() error {
+func (u *User) Validate() error {
return nil
}
+
+func (self *User) Sparkle(recipient *User, reason string) *Sparkle {
+ return &Sparkle{
+ Recipient: recipient,
+ Author: self,
+ Reason: reason,
+ }
+}
diff --git a/pkg/domain/user_test.go b/pkg/domain/user_test.go
new file mode 100644
index 0000000..dbdba6d
--- /dev/null
+++ b/pkg/domain/user_test.go
@@ -0,0 +1,24 @@
+package domain
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func TestUser(t *testing.T) {
+ t.Run("Sparkle", func(t *testing.T) {
+ t.Run("returns a new Sparkle", func(t *testing.T) {
+ tanuki := &User{}
+ user := &User{}
+
+ sparkle := user.Sparkle(tanuki, "for helping me with my homework")
+
+ require.NotNil(t, sparkle)
+ assert.Equal(t, tanuki, sparkle.Recipient)
+ assert.Equal(t, "for helping me with my homework", sparkle.Reason)
+ assert.Equal(t, user, sparkle.Author)
+ })
+ })
+}
diff --git a/pkg/oidc/id_token.go b/pkg/oidc/id_token.go
index 962f372..aa79db7 100644
--- a/pkg/oidc/id_token.go
+++ b/pkg/oidc/id_token.go
@@ -52,6 +52,15 @@ Example ID Token from GitLab OIDC Provider:
*/
type IDToken = oidc.IDToken
+type CustomClaims struct {
+ Name string `json:"name"`
+ Nickname string `json:"nickname"`
+ Email string `json:"email"`
+ ProfileURL string `json:"profile"`
+ Picture string `json:"picture"`
+ Groups []string `json:"groups_direct"`
+}
+
type RawToken string
func (r RawToken) String() string {
diff --git a/pkg/web/middleware/init.go b/pkg/web/middleware/init.go
index ccf4836..ac06c32 100644
--- a/pkg/web/middleware/init.go
+++ b/pkg/web/middleware/init.go
@@ -8,8 +8,17 @@ import (
func init() {
mapper.Register(func(idToken *oidc.IDToken) *domain.User {
+ customClaims := &oidc.CustomClaims{}
+ if err := idToken.Claims(customClaims); err != nil {
+ return &domain.User{ID: domain.ID(idToken.Subject)}
+ }
+
return &domain.User{
- ID: idToken.Subject,
+ ID: domain.ID(idToken.Subject),
+ Username: customClaims.Nickname,
+ Email: customClaims.Email,
+ ProfileURL: customClaims.ProfileURL,
+ Picture: customClaims.Picture,
}
})
}
diff --git a/pkg/web/middleware/user.go b/pkg/web/middleware/user.go
index 1e95ce0..68d2daa 100644
--- a/pkg/web/middleware/user.go
+++ b/pkg/web/middleware/user.go
@@ -21,7 +21,7 @@ func User(db db.Repository[*domain.User]) func(http.Handler) http.Handler {
return
}
- user := db.Find(idToken.Subject)
+ user := db.Find(domain.ID(idToken.Subject))
if x.IsZero(user) {
user = mapper.MapFrom[*oidc.IDToken, *domain.User](idToken)
if err := db.Save(user); err != nil {
diff --git a/pkg/web/middleware/user_test.go b/pkg/web/middleware/user_test.go
index 447d54d..c18bfdb 100644
--- a/pkg/web/middleware/user_test.go
+++ b/pkg/web/middleware/user_test.go
@@ -18,7 +18,7 @@ func TestUser(t *testing.T) {
repository := db.NewRepository[*domain.User]()
middleware := User(repository)
- knownUser := &domain.User{ID: pls.GenerateULID()}
+ knownUser := &domain.User{ID: domain.ID(pls.GenerateULID())}
require.NoError(t, repository.Save(knownUser))
t.Run("when ID Token is provided", func(t *testing.T) {
@@ -31,7 +31,7 @@ func TestUser(t *testing.T) {
w.WriteHeader(http.StatusTeapot)
}))
- ctx := key.IDToken.With(t.Context(), &oidc.IDToken{Subject: knownUser.ID})
+ ctx := key.IDToken.With(t.Context(), &oidc.IDToken{Subject: knownUser.ID.String()})
r, w := test.RequestResponse("GET", "/example", test.WithContext(ctx))
server.ServeHTTP(w, r)
@@ -45,7 +45,7 @@ func TestUser(t *testing.T) {
server := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
user := key.CurrentUser.From(r.Context())
require.NotNil(t, user)
- assert.Equal(t, unknownID, user.ID)
+ assert.Equal(t, domain.ID(unknownID), user.ID)
w.WriteHeader(http.StatusTeapot)
}))
@@ -56,7 +56,7 @@ func TestUser(t *testing.T) {
server.ServeHTTP(w, r)
assert.Equal(t, http.StatusTeapot, w.Code)
- require.NotNil(t, repository.Find(unknownID))
+ require.NotNil(t, repository.Find(domain.ID(unknownID)))
})
})