package policies

import (
	"fmt"
	"testing"

	"github.com/cedar-policy/cedar-go"
	"github.com/stretchr/testify/assert"
	"gitlab.com/mokhax/spike/pkg/gid"
)

// build creates the baseline request used by the tests in this file and then
// hands it to f so the caller can override individual fields.
//
// The baseline is: principal gid://example/User/1, action HttpMethod "GET",
// resource HttpPath "/", and a context whose "host" is "idp.example.com".
// f must not be nil; it is always invoked.
func build(f func(*cedar.Request)) *cedar.Request {
	var req cedar.Request
	req.Principal = gid.NewEntityUID("gid://example/User/1")
	req.Action = cedar.NewEntityUID("HttpMethod", "GET")
	req.Resource = cedar.NewEntityUID("HttpPath", "/")
	req.Context = cedar.NewRecord(cedar.RecordMap{
		"host": cedar.String("idp.example.com"),
	})

	f(&req)
	return &req
}

// TestAllowed feeds a list of requests that are expected to be permitted and a
// list that is expected to be rejected through Allowed, one subtest each.
//
// Every request starts from the baseline produced by build; each case only
// lists the fields it overrides.
func TestAllowed(t *testing.T) {
	// tweak overrides one field of a request.
	type tweak func(*cedar.Request)

	// Principal overrides.
	userOne := tweak(func(r *cedar.Request) {
		r.Principal = gid.NewEntityUID("gid://example/User/1")
	})
	wildcardUser := tweak(func(r *cedar.Request) {
		r.Principal = gid.NewEntityUID("gid://example/User/*")
	})
	zeroUser := tweak(func(r *cedar.Request) {
		r.Principal = gid.ZeroEntityUID()
	})

	// Action, resource and context overrides.
	method := func(verb cedar.String) tweak {
		return func(r *cedar.Request) {
			r.Action = cedar.NewEntityUID("HttpMethod", verb)
		}
	}
	path := func(p cedar.String) tweak {
		return func(r *cedar.Request) {
			r.Resource = cedar.NewEntityUID("HttpPath", p)
		}
	}
	host := func(name cedar.String) tweak {
		return func(r *cedar.Request) {
			r.Context = cedar.NewRecord(cedar.RecordMap{
				"host": name,
			})
		}
	}

	// with builds a baseline request and applies the tweaks left to right.
	with := func(tweaks ...tweak) *cedar.Request {
		return build(func(r *cedar.Request) {
			for _, apply := range tweaks {
				apply(r)
			}
		})
	}

	// parts returns the values interpolated into a subtest name, in the order
	// the name format expects them.
	parts := func(r *cedar.Request) []any {
		return []any{
			r.Principal.Type,
			r.Principal.ID,
			r.Action.ID,
			r.Context.Map()["host"],
			r.Resource.ID,
		}
	}

	allowed := []*cedar.Request{
		with(),
		with(userOne, method("POST")),
		with(userOne, method("PUT")),
		with(userOne, method("PATCH")),
		with(userOne, method("DELETE")),
		with(userOne, method("HEAD")),
		with(userOne, path("/organizations.json"), host("api.example.com")),
		with(userOne, path("/groups.json"), host("api.example.com")),
		with(userOne, path("/.well-known/openid-configuration"), host("idp.example.com")),
		with(userOne, path("/.well-known/oauth-authorization-server"), host("idp.example.com")),
		with(wildcardUser, path("/.well-known/openid-configuration"), host("idp.example.com")),
		with(wildcardUser, path("/.well-known/oauth-authorization-server"), host("idp.example.com")),
		with(userOne, method("POST"), path("/twirp/authx.rpc.Ability/Allowed"), host("idp.example.com")),
		with(userOne, method("GET"), path("/index.html"), host("ui.example.com")),
	}
	for _, req := range allowed {
		name := fmt.Sprintf("allows: %v/%v %v %v%v", parts(req)...)
		t.Run(name, func(t *testing.T) {
			assert.True(t, Allowed(*req))
		})
	}

	// All denied cases use the principal returned by gid.ZeroEntityUID and
	// keep the baseline resource and host.
	denied := []*cedar.Request{
		with(zeroUser, method("POST")),
		with(zeroUser, method("PUT")),
		with(zeroUser, method("PATCH")),
		with(zeroUser, method("DELETE")),
		with(zeroUser, method("HEAD")),
		with(zeroUser, method("TRACE")),
	}
	for _, req := range denied {
		name := fmt.Sprintf("denies: %v/%v %v %v%v", parts(req)...)
		t.Run(name, func(t *testing.T) {
			assert.False(t, Allowed(*req))
		})
	}
}