From 82a137bf926f2268b7559a9bb8e97df03780e1e3 Mon Sep 17 00:00:00 2001 From: mo khan Date: Fri, 12 Sep 2025 17:22:39 -0600 Subject: feat: inject project ids in request before forwarding --- pkg/authz/check_service.go | 49 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/pkg/authz/check_service.go b/pkg/authz/check_service.go index 7d07069c..38e8b410 100644 --- a/pkg/authz/check_service.go +++ b/pkg/authz/check_service.go @@ -2,8 +2,10 @@ package authz import ( "context" + "io" "net/http" "path/filepath" + "strings" v1 "github.com/authzed/authzed-go/proto/authzed/api/v1" authzed "github.com/authzed/authzed-go/v1" @@ -33,7 +35,7 @@ func NewCheckService(client *authzed.Client) auth.AuthorizationServer { func (svc *CheckService) Check(ctx context.Context, request *auth.CheckRequest) (*auth.CheckResponse, error) { if svc.isAuthorized(ctx, request) { - return svc.OK(ctx), nil + return svc.OK(ctx, svc.injectHeaders(ctx, request)), nil } return svc.Denied(ctx), nil } @@ -80,19 +82,19 @@ func (svc *CheckService) validRequest(ctx context.Context, r *auth.CheckRequest) x.IsPresent(r.Attributes.Request.Http) } -func (svc *CheckService) OK(ctx context.Context) *auth.CheckResponse { +func (svc *CheckService) OK(ctx context.Context, f x.Option[*auth.CheckResponse_OkResponse]) *auth.CheckResponse { log.WithFields(ctx, log.Fields{"authorized": true}) return &auth.CheckResponse{ Status: &status.Status{ Code: int32(codes.OK), }, - HttpResponse: &auth.CheckResponse_OkResponse{ + HttpResponse: f(&auth.CheckResponse_OkResponse{ OkResponse: &auth.OkHttpResponse{ Headers: []*core.HeaderValueOption{}, HeadersToRemove: []string{}, ResponseHeadersToAdd: []*core.HeaderValueOption{}, }, - }, + }), } } @@ -112,3 +114,42 @@ func (svc *CheckService) Denied(ctx context.Context) *auth.CheckResponse { }, } } + +func (svc *CheckService) injectHeaders(ctx context.Context, request *auth.CheckRequest) x.Option[*auth.CheckResponse_OkResponse] { + return x.With[*auth.CheckResponse_OkResponse](func(response *auth.CheckResponse_OkResponse) { + if x.IsZero(svc.client) { + return + } + + stream, err := svc.client.LookupResources(ctx, &v1.LookupResourcesRequest{ + ResourceObjectType: "project", + Permission: "read_project", + Subject: mapper.MapFrom[*auth.CheckRequest, *v1.SubjectReference](request), + }) + if err != nil { + pls.LogError(ctx, err) + return + } + + var projectIDs []string + for { + result, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + pls.LogError(ctx, err) + break + } + projectIDs = append(projectIDs, result.ResourceObjectId) + } + + response.OkResponse.Headers = append(response.OkResponse.Headers, &core.HeaderValueOption{ + Header: &core.HeaderValue{ + Key: "x-project-ids", + Value: strings.Join(projectIDs, ","), + }, + AppendAction: core.HeaderValueOption_OVERWRITE_IF_EXISTS_OR_ADD, + }) + }) +} -- cgit v1.2.3