From c583bcd1473205104a1e1af812ed4976d30c7baa Mon Sep 17 00:00:00 2001 From: mo khan Date: Fri, 2 May 2025 14:29:41 -0600 Subject: refactor: remove anything unrelated to the authz daemon --- pkg/app/app.go | 25 - pkg/app/routes.go | 18 - pkg/authz/authz.go | 23 - pkg/authz/casbin.go | 43 - pkg/authz/cedar.go | 34 - pkg/authz/token.go | 30 - pkg/cfg/cfg.go | 34 - pkg/cfg/mux.go | 11 - pkg/cfg/option.go | 3 - pkg/cfg/tls.go | 75 -- pkg/policies/policies_test.go | 2 +- pkg/prxy/prxy.go | 43 - pkg/prxy/prxy_test.go | 49 - pkg/rpc/ability.pb.go | 12 +- pkg/rpc/ability.twirp.go | 1104 +++++++++++++++++++ pkg/rpc/ability_grpc.pb.go | 121 --- pkg/rpc/ability_service.go | 5 +- .../mokhax/spike/pkg/rpc/ability.twirp.go | 1105 -------------------- pkg/rpc/server.go | 20 +- pkg/rpc/server_test.go | 26 +- pkg/srv/srv.go | 26 - pkg/test/test.go | 49 - 22 files changed, 1133 insertions(+), 1725 deletions(-) delete mode 100644 pkg/app/app.go delete mode 100644 pkg/app/routes.go delete mode 100644 pkg/authz/authz.go delete mode 100644 pkg/authz/casbin.go delete mode 100644 pkg/authz/cedar.go delete mode 100644 pkg/authz/token.go delete mode 100644 pkg/cfg/cfg.go delete mode 100644 pkg/cfg/mux.go delete mode 100644 pkg/cfg/option.go delete mode 100644 pkg/cfg/tls.go delete mode 100644 pkg/prxy/prxy.go delete mode 100644 pkg/prxy/prxy_test.go create mode 100644 pkg/rpc/ability.twirp.go delete mode 100644 pkg/rpc/ability_grpc.pb.go delete mode 100644 pkg/rpc/gitlab.com/mokhax/spike/pkg/rpc/ability.twirp.go delete mode 100644 pkg/srv/srv.go delete mode 100644 pkg/test/test.go (limited to 'pkg') diff --git a/pkg/app/app.go b/pkg/app/app.go deleted file mode 100644 index 89a2bd34..00000000 --- a/pkg/app/app.go +++ /dev/null @@ -1,25 +0,0 @@ -package app - -import ( - "os" - - "github.com/xlgmokha/x/pkg/log" - "gitlab.com/mokhax/spike/pkg/authz" - "gitlab.com/mokhax/spike/pkg/cfg" - "gitlab.com/mokhax/spike/pkg/srv" -) - -func Start(bindAddr string) error { - logger := log.New(os.Stdout, log.Fields{"app": "gtwy"}) - mux := authz.HTTP(authz.WithCasbin(), Routes()) - return srv.Run(cfg.New( - bindAddr, - cfg.WithMux(log.HTTP(logger)(mux)), - cfg.WithTLS([]string{ - "api.example.com", - "authzd.example.com", - "idp.example.com", - "ui.example.com", - }), - )) -} diff --git a/pkg/app/routes.go b/pkg/app/routes.go deleted file mode 100644 index ff1291c2..00000000 --- a/pkg/app/routes.go +++ /dev/null @@ -1,18 +0,0 @@ -package app - -import ( - "net/http" - - "gitlab.com/mokhax/spike/pkg/prxy" -) - -func Routes() http.Handler { - mux := http.NewServeMux() - mux.Handle("/", prxy.New(map[string]string{ - "api.example.com": "http://localhost:8284", - "authzd.example.com": "http://localhost:50051", - "idp.example.com": "http://localhost:8282", - "ui.example.com": "http://localhost:8283", - })) - return mux -} diff --git a/pkg/authz/authz.go b/pkg/authz/authz.go deleted file mode 100644 index 5a93a29c..00000000 --- a/pkg/authz/authz.go +++ /dev/null @@ -1,23 +0,0 @@ -package authz - -import "net/http" - -type Authorizer interface { - Authorize(*http.Request) bool -} - -type AuthorizerFunc func(*http.Request) bool - -func (f AuthorizerFunc) Authorize(r *http.Request) bool { - return f(r) -} - -func HTTP(authorizer Authorizer, h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if authorizer.Authorize(r) { - h.ServeHTTP(w, r) - } else { - w.WriteHeader(http.StatusForbidden) - } - }) -} diff --git a/pkg/authz/casbin.go b/pkg/authz/casbin.go deleted file mode 100644 index 140bdb98..00000000 --- a/pkg/authz/casbin.go +++ /dev/null @@ -1,43 +0,0 @@ -package authz - -import ( - "fmt" - "net" - "net/http" - - "github.com/casbin/casbin/v3" - "github.com/xlgmokha/x/pkg/log" - "github.com/xlgmokha/x/pkg/x" -) - -func WithCasbin() Authorizer { - enforcer := x.Must(casbin.NewEnforcer("casbin.conf", "casbin.csv")) - - return AuthorizerFunc(func(r *http.Request) bool { - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - log.WithFields(r.Context(), log.Fields{"error": err}) - return false - } - - subject, found := TokenFrom(r).Subject() - if !found { - subject = "*" - } - ok, err := enforcer.Enforce(subject, host, r.Method, r.URL.Path) - if err != nil { - log.WithFields(r.Context(), log.Fields{"error": err}) - return false - } - - fmt.Printf("%v: %v -> %v %v%v\n", ok, subject, r.Method, host, r.URL.Path) - log.WithFields(r.Context(), log.Fields{ - "authz": ok, - "subject": subject, - "action": r.Method, - "domain": host, - "object": r.URL.Path, - }) - return ok - }) -} diff --git a/pkg/authz/cedar.go b/pkg/authz/cedar.go deleted file mode 100644 index 18674c74..00000000 --- a/pkg/authz/cedar.go +++ /dev/null @@ -1,34 +0,0 @@ -package authz - -import ( - "net" - "net/http" - - cedar "github.com/cedar-policy/cedar-go" - "github.com/xlgmokha/x/pkg/log" - "gitlab.com/mokhax/spike/pkg/gid" - "gitlab.com/mokhax/spike/pkg/policies" -) - -func WithCedar() Authorizer { - return AuthorizerFunc(func(r *http.Request) bool { - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - log.WithFields(r.Context(), log.Fields{"error": err}) - return false - } - subject, found := TokenFrom(r).Subject() - if !found { - subject = "gid://example/User/*" - } - - return policies.Allowed(cedar.Request{ - Principal: gid.NewEntityUID(subject), - Action: cedar.NewEntityUID("HttpMethod", cedar.String(r.Method)), - Resource: cedar.NewEntityUID("HttpPath", cedar.String(r.URL.Path)), - Context: cedar.NewRecord(cedar.RecordMap{ - "host": cedar.String(host), - }), - }) - }) -} diff --git a/pkg/authz/token.go b/pkg/authz/token.go deleted file mode 100644 index 2794bf4a..00000000 --- a/pkg/authz/token.go +++ /dev/null @@ -1,30 +0,0 @@ -package authz - -import ( - "net/http" - "strings" - - "github.com/lestrrat-go/jwx/v3/jwt" - "github.com/xlgmokha/x/pkg/log" -) - -func TokenFrom(r *http.Request) jwt.Token { - authorization := r.Header.Get("Authorization") - if authorization == "" || !strings.Contains(authorization, "Bearer") { - return jwt.New() - } - - token, err := jwt.ParseRequest(r, - jwt.WithContext(r.Context()), - jwt.WithHeaderKey("Authorization"), - jwt.WithValidate(false), // TODO:: Connect this to a JSON Web Key Set - jwt.WithVerify(false), // TODO:: Connect this to a JSON Web Key Set - ) - - if err != nil { - log.WithFields(r.Context(), log.Fields{"error": err}) - return jwt.New() - } - - return token -} diff --git a/pkg/cfg/cfg.go b/pkg/cfg/cfg.go deleted file mode 100644 index 0d7a6427..00000000 --- a/pkg/cfg/cfg.go +++ /dev/null @@ -1,34 +0,0 @@ -package cfg - -import ( - "crypto/tls" - "net/http" -) - -type Config struct { - BindAddress string - Mux http.Handler - TLS *tls.Config -} - -func New(addr string, options ...Option) *Config { - if addr == "" { - addr = ":0" - } - - c := &Config{ - BindAddress: addr, - Mux: http.DefaultServeMux, - } - for _, option := range options { - option(c) - } - return c -} - -func (c *Config) Run(server *http.Server) error { - if c.TLS != nil { - return server.ListenAndServeTLS("", "") - } - return server.ListenAndServe() -} diff --git a/pkg/cfg/mux.go b/pkg/cfg/mux.go deleted file mode 100644 index 6c6f4375..00000000 --- a/pkg/cfg/mux.go +++ /dev/null @@ -1,11 +0,0 @@ -package cfg - -import ( - "net/http" -) - -func WithMux(mux http.Handler) Option { - return func(config *Config) { - config.Mux = mux - } -} diff --git a/pkg/cfg/option.go b/pkg/cfg/option.go deleted file mode 100644 index 0f3e87d8..00000000 --- a/pkg/cfg/option.go +++ /dev/null @@ -1,3 +0,0 @@ -package cfg - -type Option func(*Config) diff --git a/pkg/cfg/tls.go b/pkg/cfg/tls.go deleted file mode 100644 index bce6e186..00000000 --- a/pkg/cfg/tls.go +++ /dev/null @@ -1,75 +0,0 @@ -package cfg - -import ( - "context" - "crypto/tls" - "crypto/x509" - "encoding/pem" - "io/ioutil" - "net/http" - "os" - "path/filepath" - - "github.com/caddyserver/certmagic" - "github.com/xlgmokha/x/pkg/x" - "go.uber.org/zap" -) - -func WithSelfSigned(cert, key string) Option { - certificate := x.Must(tls.LoadX509KeyPair(cert, key)) - - return func(config *Config) { - config.TLS = &tls.Config{ - MinVersion: tls.VersionTLS13, - Certificates: []tls.Certificate{certificate}, - } - } -} - -func WithTLS(domainNames []string) Option { - directoryURL := "https://localhost:8081/acme/acme/directory" - storage := &certmagic.FileStorage{ - Path: filepath.Join(x.Must(os.Getwd()), "/tmp/cache"), - } - var cache *certmagic.Cache - cache = certmagic.NewCache(certmagic.CacheOptions{ - GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) { - return certmagic.New(cache, certmagic.Config{ - Logger: x.Must(zap.NewProduction()), - OnDemand: new(certmagic.OnDemandConfig), - Storage: storage, - }), nil - }, - }) - roots := x.Must(x509.SystemCertPool()) - roots.AddCert(func() *x509.Certificate { - block, _ := pem.Decode(x.Must(ioutil.ReadFile( - filepath.Join(x.Must(os.Getwd()), "/tmp/step/certs/root_ca.crt"), - ))) - return x.Must(x509.ParseCertificate(block.Bytes)) - }()) - magic := certmagic.New(cache, certmagic.Config{ - Logger: x.Must(zap.NewProduction()), - OnDemand: new(certmagic.OnDemandConfig), - Storage: storage, - }) - issuer := certmagic.NewACMEIssuer(magic, certmagic.ACMEIssuer{ - Agreed: true, - Email: "email@example.com", - CA: directoryURL, - TestCA: directoryURL, - TrustedRoots: roots, - }) - magic.Issuers = []certmagic.Issuer{issuer} - - if err := http.ListenAndServe(":80", issuer.HTTPChallengeHandler(http.DefaultServeMux)); err != nil { - return func(*Config) {} - } - - x.Check(magic.ManageSync(context.Background(), domainNames)) - - return func(config *Config) { - config.TLS = magic.TLSConfig() - config.TLS.NextProtos = append([]string{"h2", "http/1.1"}, config.TLS.NextProtos...) - } -} diff --git a/pkg/policies/policies_test.go b/pkg/policies/policies_test.go index 24ef6c68..9dc98bcd 100644 --- a/pkg/policies/policies_test.go +++ b/pkg/policies/policies_test.go @@ -6,7 +6,7 @@ import ( "github.com/cedar-policy/cedar-go" "github.com/stretchr/testify/assert" - "gitlab.com/mokhax/spike/pkg/gid" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/authz.d/pkg/gid" ) func build(f func(*cedar.Request)) *cedar.Request { diff --git a/pkg/prxy/prxy.go b/pkg/prxy/prxy.go deleted file mode 100644 index 43565bd3..00000000 --- a/pkg/prxy/prxy.go +++ /dev/null @@ -1,43 +0,0 @@ -package prxy - -import ( - "fmt" - "net" - "net/http" - "net/http/httputil" - "net/url" - - "github.com/xlgmokha/x/pkg/log" - "github.com/xlgmokha/x/pkg/x" -) - -func New(routes map[string]string) http.Handler { - mapped := map[string]*url.URL{} - for source, destination := range routes { - mapped[source] = x.Must(url.Parse(destination)) - } - - return &httputil.ReverseProxy{ - Rewrite: func(r *httputil.ProxyRequest) { - host, _, err := net.SplitHostPort(r.In.Host) - if err != nil { - log.WithFields(r.In.Context(), log.Fields{"error": err}) - return - } - - destination := mapped[host] - r.SetXForwarded() - r.SetURL(destination) - }, - Transport: http.DefaultTransport, - FlushInterval: -1, - ErrorLog: nil, - ModifyResponse: func(r *http.Response) error { - r.Header.Add("Via", fmt.Sprintf("%v gtwy", r.Proto)) - return nil - }, - ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { - log.WithFields(r.Context(), log.Fields{"error": err}) - }, - } -} diff --git a/pkg/prxy/prxy_test.go b/pkg/prxy/prxy_test.go deleted file mode 100644 index 6f37974e..00000000 --- a/pkg/prxy/prxy_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package prxy - -import ( - "net/http" - "net/http/httptest" - "net/url" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/xlgmokha/x/pkg/x" - "gitlab.com/mokhax/spike/pkg/test" -) - -func TestProxy(t *testing.T) { - t.Run("http://idp.test", func(t *testing.T) { - var lastIdPRequest *http.Request - var lastUiRequest *http.Request - - idp := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - lastIdPRequest = r - w.WriteHeader(http.StatusOK) - })) - defer idp.Close() - - ui := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - lastUiRequest = r - w.WriteHeader(http.StatusTeapot) - })) - defer ui.Close() - - subject := New(map[string]string{ - "idp.test": idp.URL, - "ui.test": ui.URL, - }) - - r, w := test.RequestResponse("GET", "http://idp.test:8080/saml/new") - - subject.ServeHTTP(w, r) - - url := x.Must(url.Parse(idp.URL)) - - assert.Nil(t, lastUiRequest) - assert.Equal(t, http.StatusOK, w.Code) - - require.NotNil(t, lastIdPRequest) - assert.Equal(t, url.Host, lastIdPRequest.Host) - }) -} diff --git a/pkg/rpc/ability.pb.go b/pkg/rpc/ability.pb.go index 48dd0b24..939719fc 100644 --- a/pkg/rpc/ability.pb.go +++ b/pkg/rpc/ability.pb.go @@ -129,7 +129,7 @@ var File_ability_proto protoreflect.FileDescriptor const file_ability_proto_rawDesc = "" + "\n" + - "\rability.proto\x12\tauthx.rpc\"d\n" + + "\rability.proto\x12\tauthz.rpc\"d\n" + "\fAllowRequest\x12\x18\n" + "\asubject\x18\x01 \x01(\tR\asubject\x12\x1e\n" + "\n" + @@ -140,7 +140,7 @@ const file_ability_proto_rawDesc = "" + "AllowReply\x12\x16\n" + "\x06result\x18\x01 \x01(\bR\x06result2F\n" + "\aAbility\x12;\n" + - "\aAllowed\x12\x17.authx.rpc.AllowRequest\x1a\x15.authx.rpc.AllowReply\"\x00B!Z\x1fgitlab.com/mokhax/spike/pkg/rpcb\x06proto3" + "\aAllowed\x12\x17.authz.rpc.AllowRequest\x1a\x15.authz.rpc.AllowReply\"\x00B\tZ\apkg/rpcb\x06proto3" var ( file_ability_proto_rawDescOnce sync.Once @@ -156,12 +156,12 @@ func file_ability_proto_rawDescGZIP() []byte { var file_ability_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_ability_proto_goTypes = []any{ - (*AllowRequest)(nil), // 0: authx.rpc.AllowRequest - (*AllowReply)(nil), // 1: authx.rpc.AllowReply + (*AllowRequest)(nil), // 0: authz.rpc.AllowRequest + (*AllowReply)(nil), // 1: authz.rpc.AllowReply } var file_ability_proto_depIdxs = []int32{ - 0, // 0: authx.rpc.Ability.Allowed:input_type -> authx.rpc.AllowRequest - 1, // 1: authx.rpc.Ability.Allowed:output_type -> authx.rpc.AllowReply + 0, // 0: authz.rpc.Ability.Allowed:input_type -> authz.rpc.AllowRequest + 1, // 1: authz.rpc.Ability.Allowed:output_type -> authz.rpc.AllowReply 1, // [1:2] is the sub-list for method output_type 0, // [0:1] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name diff --git a/pkg/rpc/ability.twirp.go b/pkg/rpc/ability.twirp.go new file mode 100644 index 00000000..f5a33296 --- /dev/null +++ b/pkg/rpc/ability.twirp.go @@ -0,0 +1,1104 @@ +// Code generated by protoc-gen-twirp v8.1.3, DO NOT EDIT. +// source: ability.proto + +package rpc + +import context "context" +import fmt "fmt" +import http "net/http" +import io "io" +import json "encoding/json" +import strconv "strconv" +import strings "strings" + +import protojson "google.golang.org/protobuf/encoding/protojson" +import proto "google.golang.org/protobuf/proto" +import twirp "github.com/twitchtv/twirp" +import ctxsetters "github.com/twitchtv/twirp/ctxsetters" + +import bytes "bytes" +import errors "errors" +import path "path" +import url "net/url" + +// Version compatibility assertion. +// If the constant is not defined in the package, that likely means +// the package needs to be updated to work with this generated code. +// See https://twitchtv.github.io/twirp/docs/version_matrix.html +const _ = twirp.TwirpPackageMinVersion_8_1_0 + +// ================= +// Ability Interface +// ================= + +type Ability interface { + Allowed(context.Context, *AllowRequest) (*AllowReply, error) +} + +// ======================= +// Ability Protobuf Client +// ======================= + +type abilityProtobufClient struct { + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions +} + +// NewAbilityProtobufClient creates a Protobuf client that implements the Ability interface. +// It communicates using Protobuf and can be configured with a custom HTTPClient. +func NewAbilityProtobufClient(baseURL string, client HTTPClient, opts ...twirp.ClientOption) Ability { + if c, ok := client.(*http.Client); ok { + client = withoutRedirects(c) + } + + clientOpts := twirp.ClientOptions{} + for _, o := range opts { + o(&clientOpts) + } + + // Using ReadOpt allows backwards and forwards compatibility with new options in the future + literalURLs := false + _ = clientOpts.ReadOpt("literalURLs", &literalURLs) + var pathPrefix string + if ok := clientOpts.ReadOpt("pathPrefix", &pathPrefix); !ok { + pathPrefix = "/twirp" // default prefix + } + + // Build method URLs: []/./ + serviceURL := sanitizeBaseURL(baseURL) + serviceURL += baseServicePath(pathPrefix, "authz.rpc", "Ability") + urls := [1]string{ + serviceURL + "Allowed", + } + + return &abilityProtobufClient{ + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, + } +} + +func (c *abilityProtobufClient) Allowed(ctx context.Context, in *AllowRequest) (*AllowReply, error) { + ctx = ctxsetters.WithPackageName(ctx, "authz.rpc") + ctx = ctxsetters.WithServiceName(ctx, "Ability") + ctx = ctxsetters.WithMethodName(ctx, "Allowed") + caller := c.callAllowed + if c.interceptor != nil { + caller = func(ctx context.Context, req *AllowRequest) (*AllowReply, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*AllowRequest) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*AllowRequest) when calling interceptor") + } + return c.callAllowed(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*AllowReply) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*AllowReply) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *abilityProtobufClient) callAllowed(ctx context.Context, in *AllowRequest) (*AllowReply, error) { + out := new(AllowReply) + ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) + if err != nil { + twerr, ok := err.(twirp.Error) + if !ok { + twerr = twirp.InternalErrorWith(err) + } + callClientError(ctx, c.opts.Hooks, twerr) + return nil, err + } + + callClientResponseReceived(ctx, c.opts.Hooks) + + return out, nil +} + +// =================== +// Ability JSON Client +// =================== + +type abilityJSONClient struct { + client HTTPClient + urls [1]string + interceptor twirp.Interceptor + opts twirp.ClientOptions +} + +// NewAbilityJSONClient creates a JSON client that implements the Ability interface. +// It communicates using JSON and can be configured with a custom HTTPClient. +func NewAbilityJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOption) Ability { + if c, ok := client.(*http.Client); ok { + client = withoutRedirects(c) + } + + clientOpts := twirp.ClientOptions{} + for _, o := range opts { + o(&clientOpts) + } + + // Using ReadOpt allows backwards and forwards compatibility with new options in the future + literalURLs := false + _ = clientOpts.ReadOpt("literalURLs", &literalURLs) + var pathPrefix string + if ok := clientOpts.ReadOpt("pathPrefix", &pathPrefix); !ok { + pathPrefix = "/twirp" // default prefix + } + + // Build method URLs: []/./ + serviceURL := sanitizeBaseURL(baseURL) + serviceURL += baseServicePath(pathPrefix, "authz.rpc", "Ability") + urls := [1]string{ + serviceURL + "Allowed", + } + + return &abilityJSONClient{ + client: client, + urls: urls, + interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), + opts: clientOpts, + } +} + +func (c *abilityJSONClient) Allowed(ctx context.Context, in *AllowRequest) (*AllowReply, error) { + ctx = ctxsetters.WithPackageName(ctx, "authz.rpc") + ctx = ctxsetters.WithServiceName(ctx, "Ability") + ctx = ctxsetters.WithMethodName(ctx, "Allowed") + caller := c.callAllowed + if c.interceptor != nil { + caller = func(ctx context.Context, req *AllowRequest) (*AllowReply, error) { + resp, err := c.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*AllowRequest) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*AllowRequest) when calling interceptor") + } + return c.callAllowed(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*AllowReply) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*AllowReply) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + return caller(ctx, in) +} + +func (c *abilityJSONClient) callAllowed(ctx context.Context, in *AllowRequest) (*AllowReply, error) { + out := new(AllowReply) + ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) + if err != nil { + twerr, ok := err.(twirp.Error) + if !ok { + twerr = twirp.InternalErrorWith(err) + } + callClientError(ctx, c.opts.Hooks, twerr) + return nil, err + } + + callClientResponseReceived(ctx, c.opts.Hooks) + + return out, nil +} + +// ====================== +// Ability Server Handler +// ====================== + +type abilityServer struct { + Ability + interceptor twirp.Interceptor + hooks *twirp.ServerHooks + pathPrefix string // prefix for routing + jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response + jsonCamelCase bool // JSON fields are serialized as lowerCamelCase rather than keeping the original proto names +} + +// NewAbilityServer builds a TwirpServer that can be used as an http.Handler to handle +// HTTP requests that are routed to the right method in the provided svc implementation. +// The opts are twirp.ServerOption modifiers, for example twirp.WithServerHooks(hooks). +func NewAbilityServer(svc Ability, opts ...interface{}) TwirpServer { + serverOpts := newServerOpts(opts) + + // Using ReadOpt allows backwards and forwards compatibility with new options in the future + jsonSkipDefaults := false + _ = serverOpts.ReadOpt("jsonSkipDefaults", &jsonSkipDefaults) + jsonCamelCase := false + _ = serverOpts.ReadOpt("jsonCamelCase", &jsonCamelCase) + var pathPrefix string + if ok := serverOpts.ReadOpt("pathPrefix", &pathPrefix); !ok { + pathPrefix = "/twirp" // default prefix + } + + return &abilityServer{ + Ability: svc, + hooks: serverOpts.Hooks, + interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), + pathPrefix: pathPrefix, + jsonSkipDefaults: jsonSkipDefaults, + jsonCamelCase: jsonCamelCase, + } +} + +// writeError writes an HTTP response with a valid Twirp error format, and triggers hooks. +// If err is not a twirp.Error, it will get wrapped with twirp.InternalErrorWith(err) +func (s *abilityServer) writeError(ctx context.Context, resp http.ResponseWriter, err error) { + writeError(ctx, resp, err, s.hooks) +} + +// handleRequestBodyError is used to handle error when the twirp server cannot read request +func (s *abilityServer) handleRequestBodyError(ctx context.Context, resp http.ResponseWriter, msg string, err error) { + if context.Canceled == ctx.Err() { + s.writeError(ctx, resp, twirp.NewError(twirp.Canceled, "failed to read request: context canceled")) + return + } + if context.DeadlineExceeded == ctx.Err() { + s.writeError(ctx, resp, twirp.NewError(twirp.DeadlineExceeded, "failed to read request: deadline exceeded")) + return + } + s.writeError(ctx, resp, twirp.WrapError(malformedRequestError(msg), err)) +} + +// AbilityPathPrefix is a convenience constant that may identify URL paths. +// Should be used with caution, it only matches routes generated by Twirp Go clients, +// with the default "/twirp" prefix and default CamelCase service and method names. +// More info: https://twitchtv.github.io/twirp/docs/routing.html +const AbilityPathPrefix = "/twirp/authz.rpc.Ability/" + +func (s *abilityServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) { + ctx := req.Context() + ctx = ctxsetters.WithPackageName(ctx, "authz.rpc") + ctx = ctxsetters.WithServiceName(ctx, "Ability") + ctx = ctxsetters.WithResponseWriter(ctx, resp) + + var err error + ctx, err = callRequestReceived(ctx, s.hooks) + if err != nil { + s.writeError(ctx, resp, err) + return + } + + if req.Method != "POST" { + msg := fmt.Sprintf("unsupported method %q (only POST is allowed)", req.Method) + s.writeError(ctx, resp, badRouteError(msg, req.Method, req.URL.Path)) + return + } + + // Verify path format: []/./ + prefix, pkgService, method := parseTwirpPath(req.URL.Path) + if pkgService != "authz.rpc.Ability" { + msg := fmt.Sprintf("no handler for path %q", req.URL.Path) + s.writeError(ctx, resp, badRouteError(msg, req.Method, req.URL.Path)) + return + } + if prefix != s.pathPrefix { + msg := fmt.Sprintf("invalid path prefix %q, expected %q, on path %q", prefix, s.pathPrefix, req.URL.Path) + s.writeError(ctx, resp, badRouteError(msg, req.Method, req.URL.Path)) + return + } + + switch method { + case "Allowed": + s.serveAllowed(ctx, resp, req) + return + default: + msg := fmt.Sprintf("no handler for path %q", req.URL.Path) + s.writeError(ctx, resp, badRouteError(msg, req.Method, req.URL.Path)) + return + } +} + +func (s *abilityServer) serveAllowed(ctx context.Context, resp http.ResponseWriter, req *http.Request) { + header := req.Header.Get("Content-Type") + i := strings.Index(header, ";") + if i == -1 { + i = len(header) + } + switch strings.TrimSpace(strings.ToLower(header[:i])) { + case "application/json": + s.serveAllowedJSON(ctx, resp, req) + case "application/protobuf": + s.serveAllowedProtobuf(ctx, resp, req) + default: + msg := fmt.Sprintf("unexpected Content-Type: %q", req.Header.Get("Content-Type")) + twerr := badRouteError(msg, req.Method, req.URL.Path) + s.writeError(ctx, resp, twerr) + } +} + +func (s *abilityServer) serveAllowedJSON(ctx context.Context, resp http.ResponseWriter, req *http.Request) { + var err error + ctx = ctxsetters.WithMethodName(ctx, "Allowed") + ctx, err = callRequestRouted(ctx, s.hooks) + if err != nil { + s.writeError(ctx, resp, err) + return + } + + d := json.NewDecoder(req.Body) + rawReqBody := json.RawMessage{} + if err := d.Decode(&rawReqBody); err != nil { + s.handleRequestBodyError(ctx, resp, "the json request could not be decoded", err) + return + } + reqContent := new(AllowRequest) + unmarshaler := protojson.UnmarshalOptions{DiscardUnknown: true} + if err = unmarshaler.Unmarshal(rawReqBody, reqContent); err != nil { + s.handleRequestBodyError(ctx, resp, "the json request could not be decoded", err) + return + } + + handler := s.Ability.Allowed + if s.interceptor != nil { + handler = func(ctx context.Context, req *AllowRequest) (*AllowReply, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*AllowRequest) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*AllowRequest) when calling interceptor") + } + return s.Ability.Allowed(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*AllowReply) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*AllowReply) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + + // Call service method + var respContent *AllowReply + func() { + defer ensurePanicResponses(ctx, resp, s.hooks) + respContent, err = handler(ctx, reqContent) + }() + + if err != nil { + s.writeError(ctx, resp, err) + return + } + if respContent == nil { + s.writeError(ctx, resp, twirp.InternalError("received a nil *AllowReply and nil error while calling Allowed. nil responses are not supported")) + return + } + + ctx = callResponsePrepared(ctx, s.hooks) + + marshaler := &protojson.MarshalOptions{UseProtoNames: !s.jsonCamelCase, EmitUnpopulated: !s.jsonSkipDefaults} + respBytes, err := marshaler.Marshal(respContent) + if err != nil { + s.writeError(ctx, resp, wrapInternal(err, "failed to marshal json response")) + return + } + + ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/json") + resp.Header().Set("Content-Length", strconv.Itoa(len(respBytes))) + resp.WriteHeader(http.StatusOK) + + if n, err := resp.Write(respBytes); err != nil { + msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) + twerr := twirp.NewError(twirp.Unknown, msg) + ctx = callError(ctx, s.hooks, twerr) + } + callResponseSent(ctx, s.hooks) +} + +func (s *abilityServer) serveAllowedProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { + var err error + ctx = ctxsetters.WithMethodName(ctx, "Allowed") + ctx, err = callRequestRouted(ctx, s.hooks) + if err != nil { + s.writeError(ctx, resp, err) + return + } + + buf, err := io.ReadAll(req.Body) + if err != nil { + s.handleRequestBodyError(ctx, resp, "failed to read request body", err) + return + } + reqContent := new(AllowRequest) + if err = proto.Unmarshal(buf, reqContent); err != nil { + s.writeError(ctx, resp, malformedRequestError("the protobuf request could not be decoded")) + return + } + + handler := s.Ability.Allowed + if s.interceptor != nil { + handler = func(ctx context.Context, req *AllowRequest) (*AllowReply, error) { + resp, err := s.interceptor( + func(ctx context.Context, req interface{}) (interface{}, error) { + typedReq, ok := req.(*AllowRequest) + if !ok { + return nil, twirp.InternalError("failed type assertion req.(*AllowRequest) when calling interceptor") + } + return s.Ability.Allowed(ctx, typedReq) + }, + )(ctx, req) + if resp != nil { + typedResp, ok := resp.(*AllowReply) + if !ok { + return nil, twirp.InternalError("failed type assertion resp.(*AllowReply) when calling interceptor") + } + return typedResp, err + } + return nil, err + } + } + + // Call service method + var respContent *AllowReply + func() { + defer ensurePanicResponses(ctx, resp, s.hooks) + respContent, err = handler(ctx, reqContent) + }() + + if err != nil { + s.writeError(ctx, resp, err) + return + } + if respContent == nil { + s.writeError(ctx, resp, twirp.InternalError("received a nil *AllowReply and nil error while calling Allowed. nil responses are not supported")) + return + } + + ctx = callResponsePrepared(ctx, s.hooks) + + respBytes, err := proto.Marshal(respContent) + if err != nil { + s.writeError(ctx, resp, wrapInternal(err, "failed to marshal proto response")) + return + } + + ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) + resp.Header().Set("Content-Type", "application/protobuf") + resp.Header().Set("Content-Length", strconv.Itoa(len(respBytes))) + resp.WriteHeader(http.StatusOK) + if n, err := resp.Write(respBytes); err != nil { + msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) + twerr := twirp.NewError(twirp.Unknown, msg) + ctx = callError(ctx, s.hooks, twerr) + } + callResponseSent(ctx, s.hooks) +} + +func (s *abilityServer) ServiceDescriptor() ([]byte, int) { + return twirpFileDescriptor0, 0 +} + +func (s *abilityServer) ProtocGenTwirpVersion() string { + return "v8.1.3" +} + +// PathPrefix returns the base service path, in the form: "//./" +// that is everything in a Twirp route except for the . This can be used for routing, +// for example to identify the requests that are targeted to this service in a mux. +func (s *abilityServer) PathPrefix() string { + return baseServicePath(s.pathPrefix, "authz.rpc", "Ability") +} + +// ===== +// Utils +// ===== + +// HTTPClient is the interface used by generated clients to send HTTP requests. +// It is fulfilled by *(net/http).Client, which is sufficient for most users. +// Users can provide their own implementation for special retry policies. +// +// HTTPClient implementations should not follow redirects. Redirects are +// automatically disabled if *(net/http).Client is passed to client +// constructors. See the withoutRedirects function in this file for more +// details. +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +// TwirpServer is the interface generated server structs will support: they're +// HTTP handlers with additional methods for accessing metadata about the +// service. Those accessors are a low-level API for building reflection tools. +// Most people can think of TwirpServers as just http.Handlers. +type TwirpServer interface { + http.Handler + + // ServiceDescriptor returns gzipped bytes describing the .proto file that + // this service was generated from. Once unzipped, the bytes can be + // unmarshalled as a + // google.golang.org/protobuf/types/descriptorpb.FileDescriptorProto. + // + // The returned integer is the index of this particular service within that + // FileDescriptorProto's 'Service' slice of ServiceDescriptorProtos. This is a + // low-level field, expected to be used for reflection. + ServiceDescriptor() ([]byte, int) + + // ProtocGenTwirpVersion is the semantic version string of the version of + // twirp used to generate this file. + ProtocGenTwirpVersion() string + + // PathPrefix returns the HTTP URL path prefix for all methods handled by this + // service. This can be used with an HTTP mux to route Twirp requests. + // The path prefix is in the form: "//./" + // that is, everything in a Twirp route except for the at the end. + PathPrefix() string +} + +func newServerOpts(opts []interface{}) *twirp.ServerOptions { + serverOpts := &twirp.ServerOptions{} + for _, opt := range opts { + switch o := opt.(type) { + case twirp.ServerOption: + o(serverOpts) + case *twirp.ServerHooks: // backwards compatibility, allow to specify hooks as an argument + twirp.WithServerHooks(o)(serverOpts) + case nil: // backwards compatibility, allow nil value for the argument + continue + default: + panic(fmt.Sprintf("Invalid option type %T, please use a twirp.ServerOption", o)) + } + } + return serverOpts +} + +// WriteError writes an HTTP response with a valid Twirp error format (code, msg, meta). +// Useful outside of the Twirp server (e.g. http middleware), but does not trigger hooks. +// If err is not a twirp.Error, it will get wrapped with twirp.InternalErrorWith(err) +func WriteError(resp http.ResponseWriter, err error) { + writeError(context.Background(), resp, err, nil) +} + +// writeError writes Twirp errors in the response and triggers hooks. +func writeError(ctx context.Context, resp http.ResponseWriter, err error, hooks *twirp.ServerHooks) { + // Convert to a twirp.Error. Non-twirp errors are converted to internal errors. + var twerr twirp.Error + if !errors.As(err, &twerr) { + twerr = twirp.InternalErrorWith(err) + } + + statusCode := twirp.ServerHTTPStatusFromErrorCode(twerr.Code()) + ctx = ctxsetters.WithStatusCode(ctx, statusCode) + ctx = callError(ctx, hooks, twerr) + + respBody := marshalErrorToJSON(twerr) + + resp.Header().Set("Content-Type", "application/json") // Error responses are always JSON + resp.Header().Set("Content-Length", strconv.Itoa(len(respBody))) + resp.WriteHeader(statusCode) // set HTTP status code and send response + + _, writeErr := resp.Write(respBody) + if writeErr != nil { + // We have three options here. We could log the error, call the Error + // hook, or just silently ignore the error. + // + // Logging is unacceptable because we don't have a user-controlled + // logger; writing out to stderr without permission is too rude. + // + // Calling the Error hook would confuse users: it would mean the Error + // hook got called twice for one request, which is likely to lead to + // duplicated log messages and metrics, no matter how well we document + // the behavior. + // + // Silently ignoring the error is our least-bad option. It's highly + // likely that the connection is broken and the original 'err' says + // so anyway. + _ = writeErr + } + + callResponseSent(ctx, hooks) +} + +// sanitizeBaseURL parses the the baseURL, and adds the "http" scheme if needed. +// If the URL is unparsable, the baseURL is returned unchanged. +func sanitizeBaseURL(baseURL string) string { + u, err := url.Parse(baseURL) + if err != nil { + return baseURL // invalid URL will fail later when making requests + } + if u.Scheme == "" { + u.Scheme = "http" + } + return u.String() +} + +// baseServicePath composes the path prefix for the service (without ). +// e.g.: baseServicePath("/twirp", "my.pkg", "MyService") +// +// returns => "/twirp/my.pkg.MyService/" +// +// e.g.: baseServicePath("", "", "MyService") +// +// returns => "/MyService/" +func baseServicePath(prefix, pkg, service string) string { + fullServiceName := service + if pkg != "" { + fullServiceName = pkg + "." + service + } + return path.Join("/", prefix, fullServiceName) + "/" +} + +// parseTwirpPath extracts path components form a valid Twirp route. +// Expected format: "[]/./" +// e.g.: prefix, pkgService, method := parseTwirpPath("/twirp/pkg.Svc/MakeHat") +func parseTwirpPath(path string) (string, string, string) { + parts := strings.Split(path, "/") + if len(parts) < 2 { + return "", "", "" + } + method := parts[len(parts)-1] + pkgService := parts[len(parts)-2] + prefix := strings.Join(parts[0:len(parts)-2], "/") + return prefix, pkgService, method +} + +// getCustomHTTPReqHeaders retrieves a copy of any headers that are set in +// a context through the twirp.WithHTTPRequestHeaders function. +// If there are no headers set, or if they have the wrong type, nil is returned. +func getCustomHTTPReqHeaders(ctx context.Context) http.Header { + header, ok := twirp.HTTPRequestHeaders(ctx) + if !ok || header == nil { + return nil + } + copied := make(http.Header) + for k, vv := range header { + if vv == nil { + copied[k] = nil + continue + } + copied[k] = make([]string, len(vv)) + copy(copied[k], vv) + } + return copied +} + +// newRequest makes an http.Request from a client, adding common headers. +func newRequest(ctx context.Context, url string, reqBody io.Reader, contentType string) (*http.Request, error) { + req, err := http.NewRequest("POST", url, reqBody) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if customHeader := getCustomHTTPReqHeaders(ctx); customHeader != nil { + req.Header = customHeader + } + req.Header.Set("Accept", contentType) + req.Header.Set("Content-Type", contentType) + req.Header.Set("Twirp-Version", "v8.1.3") + return req, nil +} + +// JSON serialization for errors +type twerrJSON struct { + Code string `json:"code"` + Msg string `json:"msg"` + Meta map[string]string `json:"meta,omitempty"` +} + +// marshalErrorToJSON returns JSON from a twirp.Error, that can be used as HTTP error response body. +// If serialization fails, it will use a descriptive Internal error instead. +func marshalErrorToJSON(twerr twirp.Error) []byte { + // make sure that msg is not too large + msg := twerr.Msg() + if len(msg) > 1e6 { + msg = msg[:1e6] + } + + tj := twerrJSON{ + Code: string(twerr.Code()), + Msg: msg, + Meta: twerr.MetaMap(), + } + + buf, err := json.Marshal(&tj) + if err != nil { + buf = []byte("{\"type\": \"" + twirp.Internal + "\", \"msg\": \"There was an error but it could not be serialized into JSON\"}") // fallback + } + + return buf +} + +// errorFromResponse builds a twirp.Error from a non-200 HTTP response. +// If the response has a valid serialized Twirp error, then it's returned. +// If not, the response status code is used to generate a similar twirp +// error. See twirpErrorFromIntermediary for more info on intermediary errors. +func errorFromResponse(resp *http.Response) twirp.Error { + statusCode := resp.StatusCode + statusText := http.StatusText(statusCode) + + if isHTTPRedirect(statusCode) { + // Unexpected redirect: it must be an error from an intermediary. + // Twirp clients don't follow redirects automatically, Twirp only handles + // POST requests, redirects should only happen on GET and HEAD requests. + location := resp.Header.Get("Location") + msg := fmt.Sprintf("unexpected HTTP status code %d %q received, Location=%q", statusCode, statusText, location) + return twirpErrorFromIntermediary(statusCode, msg, location) + } + + respBodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return wrapInternal(err, "failed to read server error response body") + } + + var tj twerrJSON + dec := json.NewDecoder(bytes.NewReader(respBodyBytes)) + dec.DisallowUnknownFields() + if err := dec.Decode(&tj); err != nil || tj.Code == "" { + // Invalid JSON response; it must be an error from an intermediary. + msg := fmt.Sprintf("Error from intermediary with HTTP status code %d %q", statusCode, statusText) + return twirpErrorFromIntermediary(statusCode, msg, string(respBodyBytes)) + } + + errorCode := twirp.ErrorCode(tj.Code) + if !twirp.IsValidErrorCode(errorCode) { + msg := "invalid type returned from server error response: " + tj.Code + return twirp.InternalError(msg).WithMeta("body", string(respBodyBytes)) + } + + twerr := twirp.NewError(errorCode, tj.Msg) + for k, v := range tj.Meta { + twerr = twerr.WithMeta(k, v) + } + return twerr +} + +// twirpErrorFromIntermediary maps HTTP errors from non-twirp sources to twirp errors. +// The mapping is similar to gRPC: https://github.com/grpc/grpc/blob/master/doc/http-grpc-status-mapping.md. +// Returned twirp Errors have some additional metadata for inspection. +func twirpErrorFromIntermediary(status int, msg string, bodyOrLocation string) twirp.Error { + var code twirp.ErrorCode + if isHTTPRedirect(status) { // 3xx + code = twirp.Internal + } else { + switch status { + case 400: // Bad Request + code = twirp.Internal + case 401: // Unauthorized + code = twirp.Unauthenticated + case 403: // Forbidden + code = twirp.PermissionDenied + case 404: // Not Found + code = twirp.BadRoute + case 429: // Too Many Requests + code = twirp.ResourceExhausted + case 502, 503, 504: // Bad Gateway, Service Unavailable, Gateway Timeout + code = twirp.Unavailable + default: // All other codes + code = twirp.Unknown + } + } + + twerr := twirp.NewError(code, msg) + twerr = twerr.WithMeta("http_error_from_intermediary", "true") // to easily know if this error was from intermediary + twerr = twerr.WithMeta("status_code", strconv.Itoa(status)) + if isHTTPRedirect(status) { + twerr = twerr.WithMeta("location", bodyOrLocation) + } else { + twerr = twerr.WithMeta("body", bodyOrLocation) + } + return twerr +} + +func isHTTPRedirect(status int) bool { + return status >= 300 && status <= 399 +} + +// wrapInternal wraps an error with a prefix as an Internal error. +// The original error cause is accessible by github.com/pkg/errors.Cause. +func wrapInternal(err error, prefix string) twirp.Error { + return twirp.InternalErrorWith(&wrappedError{prefix: prefix, cause: err}) +} + +type wrappedError struct { + prefix string + cause error +} + +func (e *wrappedError) Error() string { return e.prefix + ": " + e.cause.Error() } +func (e *wrappedError) Unwrap() error { return e.cause } // for go1.13 + errors.Is/As +func (e *wrappedError) Cause() error { return e.cause } // for github.com/pkg/errors + +// ensurePanicResponses makes sure that rpc methods causing a panic still result in a Twirp Internal +// error response (status 500), and error hooks are properly called with the panic wrapped as an error. +// The panic is re-raised so it can be handled normally with middleware. +func ensurePanicResponses(ctx context.Context, resp http.ResponseWriter, hooks *twirp.ServerHooks) { + if r := recover(); r != nil { + // Wrap the panic as an error so it can be passed to error hooks. + // The original error is accessible from error hooks, but not visible in the response. + err := errFromPanic(r) + twerr := &internalWithCause{msg: "Internal service panic", cause: err} + // Actually write the error + writeError(ctx, resp, twerr, hooks) + // If possible, flush the error to the wire. + f, ok := resp.(http.Flusher) + if ok { + f.Flush() + } + + panic(r) + } +} + +// errFromPanic returns the typed error if the recovered panic is an error, otherwise formats as error. +func errFromPanic(p interface{}) error { + if err, ok := p.(error); ok { + return err + } + return fmt.Errorf("panic: %v", p) +} + +// internalWithCause is a Twirp Internal error wrapping an original error cause, +// but the original error message is not exposed on Msg(). The original error +// can be checked with go1.13+ errors.Is/As, and also by (github.com/pkg/errors).Unwrap +type internalWithCause struct { + msg string + cause error +} + +func (e *internalWithCause) Unwrap() error { return e.cause } // for go1.13 + errors.Is/As +func (e *internalWithCause) Cause() error { return e.cause } // for github.com/pkg/errors +func (e *internalWithCause) Error() string { return e.msg + ": " + e.cause.Error() } +func (e *internalWithCause) Code() twirp.ErrorCode { return twirp.Internal } +func (e *internalWithCause) Msg() string { return e.msg } +func (e *internalWithCause) Meta(key string) string { return "" } +func (e *internalWithCause) MetaMap() map[string]string { return nil } +func (e *internalWithCause) WithMeta(key string, val string) twirp.Error { return e } + +// malformedRequestError is used when the twirp server cannot unmarshal a request +func malformedRequestError(msg string) twirp.Error { + return twirp.NewError(twirp.Malformed, msg) +} + +// badRouteError is used when the twirp server cannot route a request +func badRouteError(msg string, method, url string) twirp.Error { + err := twirp.NewError(twirp.BadRoute, msg) + err = err.WithMeta("twirp_invalid_route", method+" "+url) + return err +} + +// withoutRedirects makes sure that the POST request can not be redirected. +// The standard library will, by default, redirect requests (including POSTs) if it gets a 302 or +// 303 response, and also 301s in go1.8. It redirects by making a second request, changing the +// method to GET and removing the body. This produces very confusing error messages, so instead we +// set a redirect policy that always errors. This stops Go from executing the redirect. +// +// We have to be a little careful in case the user-provided http.Client has its own CheckRedirect +// policy - if so, we'll run through that policy first. +// +// Because this requires modifying the http.Client, we make a new copy of the client and return it. +func withoutRedirects(in *http.Client) *http.Client { + copy := *in + copy.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if in.CheckRedirect != nil { + // Run the input's redirect if it exists, in case it has side effects, but ignore any error it + // returns, since we want to use ErrUseLastResponse. + err := in.CheckRedirect(req, via) + _ = err // Silly, but this makes sure generated code passes errcheck -blank, which some people use. + } + return http.ErrUseLastResponse + } + return © +} + +// doProtobufRequest makes a Protobuf request to the remote Twirp service. +func doProtobufRequest(ctx context.Context, client HTTPClient, hooks *twirp.ClientHooks, url string, in, out proto.Message) (_ context.Context, err error) { + reqBodyBytes, err := proto.Marshal(in) + if err != nil { + return ctx, wrapInternal(err, "failed to marshal proto request") + } + reqBody := bytes.NewBuffer(reqBodyBytes) + if err = ctx.Err(); err != nil { + return ctx, wrapInternal(err, "aborted because context was done") + } + + req, err := newRequest(ctx, url, reqBody, "application/protobuf") + if err != nil { + return ctx, wrapInternal(err, "could not build request") + } + ctx, err = callClientRequestPrepared(ctx, hooks, req) + if err != nil { + return ctx, err + } + + req = req.WithContext(ctx) + resp, err := client.Do(req) + if err != nil { + return ctx, wrapInternal(err, "failed to do request") + } + defer func() { _ = resp.Body.Close() }() + + if err = ctx.Err(); err != nil { + return ctx, wrapInternal(err, "aborted because context was done") + } + + if resp.StatusCode != 200 { + return ctx, errorFromResponse(resp) + } + + respBodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return ctx, wrapInternal(err, "failed to read response body") + } + if err = ctx.Err(); err != nil { + return ctx, wrapInternal(err, "aborted because context was done") + } + + if err = proto.Unmarshal(respBodyBytes, out); err != nil { + return ctx, wrapInternal(err, "failed to unmarshal proto response") + } + return ctx, nil +} + +// doJSONRequest makes a JSON request to the remote Twirp service. +func doJSONRequest(ctx context.Context, client HTTPClient, hooks *twirp.ClientHooks, url string, in, out proto.Message) (_ context.Context, err error) { + marshaler := &protojson.MarshalOptions{UseProtoNames: true} + reqBytes, err := marshaler.Marshal(in) + if err != nil { + return ctx, wrapInternal(err, "failed to marshal json request") + } + if err = ctx.Err(); err != nil { + return ctx, wrapInternal(err, "aborted because context was done") + } + + req, err := newRequest(ctx, url, bytes.NewReader(reqBytes), "application/json") + if err != nil { + return ctx, wrapInternal(err, "could not build request") + } + ctx, err = callClientRequestPrepared(ctx, hooks, req) + if err != nil { + return ctx, err + } + + req = req.WithContext(ctx) + resp, err := client.Do(req) + if err != nil { + return ctx, wrapInternal(err, "failed to do request") + } + + defer func() { + cerr := resp.Body.Close() + if err == nil && cerr != nil { + err = wrapInternal(cerr, "failed to close response body") + } + }() + + if err = ctx.Err(); err != nil { + return ctx, wrapInternal(err, "aborted because context was done") + } + + if resp.StatusCode != 200 { + return ctx, errorFromResponse(resp) + } + + d := json.NewDecoder(resp.Body) + rawRespBody := json.RawMessage{} + if err := d.Decode(&rawRespBody); err != nil { + return ctx, wrapInternal(err, "failed to unmarshal json response") + } + unmarshaler := protojson.UnmarshalOptions{DiscardUnknown: true} + if err = unmarshaler.Unmarshal(rawRespBody, out); err != nil { + return ctx, wrapInternal(err, "failed to unmarshal json response") + } + if err = ctx.Err(); err != nil { + return ctx, wrapInternal(err, "aborted because context was done") + } + return ctx, nil +} + +// Call twirp.ServerHooks.RequestReceived if the hook is available +func callRequestReceived(ctx context.Context, h *twirp.ServerHooks) (context.Context, error) { + if h == nil || h.RequestReceived == nil { + return ctx, nil + } + return h.RequestReceived(ctx) +} + +// Call twirp.ServerHooks.RequestRouted if the hook is available +func callRequestRouted(ctx context.Context, h *twirp.ServerHooks) (context.Context, error) { + if h == nil || h.RequestRouted == nil { + return ctx, nil + } + return h.RequestRouted(ctx) +} + +// Call twirp.ServerHooks.ResponsePrepared if the hook is available +func callResponsePrepared(ctx context.Context, h *twirp.ServerHooks) context.Context { + if h == nil || h.ResponsePrepared == nil { + return ctx + } + return h.ResponsePrepared(ctx) +} + +// Call twirp.ServerHooks.ResponseSent if the hook is available +func callResponseSent(ctx context.Context, h *twirp.ServerHooks) { + if h == nil || h.ResponseSent == nil { + return + } + h.ResponseSent(ctx) +} + +// Call twirp.ServerHooks.Error if the hook is available +func callError(ctx context.Context, h *twirp.ServerHooks, err twirp.Error) context.Context { + if h == nil || h.Error == nil { + return ctx + } + return h.Error(ctx, err) +} + +func callClientResponseReceived(ctx context.Context, h *twirp.ClientHooks) { + if h == nil || h.ResponseReceived == nil { + return + } + h.ResponseReceived(ctx) +} + +func callClientRequestPrepared(ctx context.Context, h *twirp.ClientHooks, req *http.Request) (context.Context, error) { + if h == nil || h.RequestPrepared == nil { + return ctx, nil + } + return h.RequestPrepared(ctx, req) +} + +func callClientError(ctx context.Context, h *twirp.ClientHooks, err twirp.Error) { + if h == nil || h.Error == nil { + return + } + h.Error(ctx, err) +} + +var twirpFileDescriptor0 = []byte{ + // 196 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0x4d, 0x4c, 0xca, 0xcc, + 0xc9, 0x2c, 0xa9, 0xd4, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xe2, 0x4c, 0x2c, 0x2d, 0xc9, 0xa8, + 0xd2, 0x2b, 0x2a, 0x48, 0x56, 0x4a, 0xe1, 0xe2, 0x71, 0xcc, 0xc9, 0xc9, 0x2f, 0x0f, 0x4a, 0x2d, + 0x2c, 0x4d, 0x2d, 0x2e, 0x11, 0x92, 0xe0, 0x62, 0x2f, 0x2e, 0x4d, 0xca, 0x4a, 0x4d, 0x2e, 0x91, + 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x0c, 0x82, 0x71, 0x85, 0xe4, 0xb8, 0xb8, 0x0a, 0x52, 0x8b, 0x72, + 0x33, 0x8b, 0x8b, 0x33, 0xf3, 0xf3, 0x24, 0x98, 0xc0, 0x92, 0x48, 0x22, 0x42, 0x52, 0x5c, 0x1c, + 0x45, 0xa9, 0xc5, 0xf9, 0xa5, 0x45, 0xc9, 0xa9, 0x12, 0xcc, 0x60, 0x59, 0x38, 0x5f, 0x49, 0x85, + 0x8b, 0x0b, 0x6a, 0x4b, 0x41, 0x4e, 0xa5, 0x90, 0x18, 0x17, 0x5b, 0x51, 0x6a, 0x71, 0x69, 0x0e, + 0xc4, 0x0a, 0x8e, 0x20, 0x28, 0xcf, 0xc8, 0x8d, 0x8b, 0xdd, 0x11, 0xe2, 0x4e, 0x21, 0x6b, 0x2e, + 0x76, 0xb0, 0x86, 0xd4, 0x14, 0x21, 0x71, 0x3d, 0xb8, 0x6b, 0xf5, 0x90, 0x9d, 0x2a, 0x25, 0x8a, + 0x29, 0x51, 0x90, 0x53, 0xa9, 0xc4, 0xe0, 0xc4, 0x19, 0xc5, 0x5e, 0x90, 0x9d, 0xae, 0x5f, 0x54, + 0x90, 0x9c, 0xc4, 0x06, 0xf6, 0xb0, 0x31, 0x20, 0x00, 0x00, 0xff, 0xff, 0x72, 0x35, 0x46, 0x7c, + 0x01, 0x01, 0x00, 0x00, +} diff --git a/pkg/rpc/ability_grpc.pb.go b/pkg/rpc/ability_grpc.pb.go deleted file mode 100644 index 4d74cc41..00000000 --- a/pkg/rpc/ability_grpc.pb.go +++ /dev/null @@ -1,121 +0,0 @@ -// Code generated by protoc-gen-go-grpc. DO NOT EDIT. -// versions: -// - protoc-gen-go-grpc v1.5.1 -// - protoc v3.19.6 -// source: ability.proto - -package rpc - -import ( - context "context" - grpc "google.golang.org/grpc" - codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" -) - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.64.0 or later. -const _ = grpc.SupportPackageIsVersion9 - -const ( - Ability_Allowed_FullMethodName = "/authx.rpc.Ability/Allowed" -) - -// AbilityClient is the client API for Ability service. -// -// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. -type AbilityClient interface { - Allowed(ctx context.Context, in *AllowRequest, opts ...grpc.CallOption) (*AllowReply, error) -} - -type abilityClient struct { - cc grpc.ClientConnInterface -} - -func NewAbilityClient(cc grpc.ClientConnInterface) AbilityClient { - return &abilityClient{cc} -} - -func (c *abilityClient) Allowed(ctx context.Context, in *AllowRequest, opts ...grpc.CallOption) (*AllowReply, error) { - cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) - out := new(AllowReply) - err := c.cc.Invoke(ctx, Ability_Allowed_FullMethodName, in, out, cOpts...) - if err != nil { - return nil, err - } - return out, nil -} - -// AbilityServer is the server API for Ability service. -// All implementations must embed UnimplementedAbilityServer -// for forward compatibility. -type AbilityServer interface { - Allowed(context.Context, *AllowRequest) (*AllowReply, error) - mustEmbedUnimplementedAbilityServer() -} - -// UnimplementedAbilityServer must be embedded to have -// forward compatible implementations. -// -// NOTE: this should be embedded by value instead of pointer to avoid a nil -// pointer dereference when methods are called. -type UnimplementedAbilityServer struct{} - -func (UnimplementedAbilityServer) Allowed(context.Context, *AllowRequest) (*AllowReply, error) { - return nil, status.Errorf(codes.Unimplemented, "method Allowed not implemented") -} -func (UnimplementedAbilityServer) mustEmbedUnimplementedAbilityServer() {} -func (UnimplementedAbilityServer) testEmbeddedByValue() {} - -// UnsafeAbilityServer may be embedded to opt out of forward compatibility for this service. -// Use of this interface is not recommended, as added methods to AbilityServer will -// result in compilation errors. -type UnsafeAbilityServer interface { - mustEmbedUnimplementedAbilityServer() -} - -func RegisterAbilityServer(s grpc.ServiceRegistrar, srv AbilityServer) { - // If the following call pancis, it indicates UnimplementedAbilityServer was - // embedded by pointer and is nil. This will cause panics if an - // unimplemented method is ever invoked, so we test this at initialization - // time to prevent it from happening at runtime later due to I/O. - if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { - t.testEmbeddedByValue() - } - s.RegisterService(&Ability_ServiceDesc, srv) -} - -func _Ability_Allowed_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(AllowRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(AbilityServer).Allowed(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: Ability_Allowed_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(AbilityServer).Allowed(ctx, req.(*AllowRequest)) - } - return interceptor(ctx, in, info, handler) -} - -// Ability_ServiceDesc is the grpc.ServiceDesc for Ability service. -// It's only intended for direct use with grpc.RegisterService, -// and not to be introspected or modified (even as a copy) -var Ability_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "authx.rpc.Ability", - HandlerType: (*AbilityServer)(nil), - Methods: []grpc.MethodDesc{ - { - MethodName: "Allowed", - Handler: _Ability_Allowed_Handler, - }, - }, - Streams: []grpc.StreamDesc{}, - Metadata: "ability.proto", -} diff --git a/pkg/rpc/ability_service.go b/pkg/rpc/ability_service.go index 18327d52..db2e8fab 100644 --- a/pkg/rpc/ability_service.go +++ b/pkg/rpc/ability_service.go @@ -4,12 +4,11 @@ import ( context "context" "github.com/cedar-policy/cedar-go" - "gitlab.com/mokhax/spike/pkg/gid" - "gitlab.com/mokhax/spike/pkg/policies" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/authz.d/pkg/gid" + "gitlab.com/gitlab-org/software-supply-chain-security/authorization/authz.d/pkg/policies" ) type AbilityService struct { - UnimplementedAbilityServer } func NewAbilityService() *AbilityService { diff --git a/pkg/rpc/gitlab.com/mokhax/spike/pkg/rpc/ability.twirp.go b/pkg/rpc/gitlab.com/mokhax/spike/pkg/rpc/ability.twirp.go deleted file mode 100644 index ea2c3d17..00000000 --- a/pkg/rpc/gitlab.com/mokhax/spike/pkg/rpc/ability.twirp.go +++ /dev/null @@ -1,1105 +0,0 @@ -// Code generated by protoc-gen-twirp v8.1.3, DO NOT EDIT. -// source: ability.proto - -package rpc - -import context "context" -import fmt "fmt" -import http "net/http" -import io "io" -import json "encoding/json" -import strconv "strconv" -import strings "strings" - -import protojson "google.golang.org/protobuf/encoding/protojson" -import proto "google.golang.org/protobuf/proto" -import twirp "github.com/twitchtv/twirp" -import ctxsetters "github.com/twitchtv/twirp/ctxsetters" - -import bytes "bytes" -import errors "errors" -import path "path" -import url "net/url" - -// Version compatibility assertion. -// If the constant is not defined in the package, that likely means -// the package needs to be updated to work with this generated code. -// See https://twitchtv.github.io/twirp/docs/version_matrix.html -const _ = twirp.TwirpPackageMinVersion_8_1_0 - -// ================= -// Ability Interface -// ================= - -type Ability interface { - Allowed(context.Context, *AllowRequest) (*AllowReply, error) -} - -// ======================= -// Ability Protobuf Client -// ======================= - -type abilityProtobufClient struct { - client HTTPClient - urls [1]string - interceptor twirp.Interceptor - opts twirp.ClientOptions -} - -// NewAbilityProtobufClient creates a Protobuf client that implements the Ability interface. -// It communicates using Protobuf and can be configured with a custom HTTPClient. -func NewAbilityProtobufClient(baseURL string, client HTTPClient, opts ...twirp.ClientOption) Ability { - if c, ok := client.(*http.Client); ok { - client = withoutRedirects(c) - } - - clientOpts := twirp.ClientOptions{} - for _, o := range opts { - o(&clientOpts) - } - - // Using ReadOpt allows backwards and forwards compatibility with new options in the future - literalURLs := false - _ = clientOpts.ReadOpt("literalURLs", &literalURLs) - var pathPrefix string - if ok := clientOpts.ReadOpt("pathPrefix", &pathPrefix); !ok { - pathPrefix = "/twirp" // default prefix - } - - // Build method URLs: []/./ - serviceURL := sanitizeBaseURL(baseURL) - serviceURL += baseServicePath(pathPrefix, "authx.rpc", "Ability") - urls := [1]string{ - serviceURL + "Allowed", - } - - return &abilityProtobufClient{ - client: client, - urls: urls, - interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), - opts: clientOpts, - } -} - -func (c *abilityProtobufClient) Allowed(ctx context.Context, in *AllowRequest) (*AllowReply, error) { - ctx = ctxsetters.WithPackageName(ctx, "authx.rpc") - ctx = ctxsetters.WithServiceName(ctx, "Ability") - ctx = ctxsetters.WithMethodName(ctx, "Allowed") - caller := c.callAllowed - if c.interceptor != nil { - caller = func(ctx context.Context, req *AllowRequest) (*AllowReply, error) { - resp, err := c.interceptor( - func(ctx context.Context, req interface{}) (interface{}, error) { - typedReq, ok := req.(*AllowRequest) - if !ok { - return nil, twirp.InternalError("failed type assertion req.(*AllowRequest) when calling interceptor") - } - return c.callAllowed(ctx, typedReq) - }, - )(ctx, req) - if resp != nil { - typedResp, ok := resp.(*AllowReply) - if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*AllowReply) when calling interceptor") - } - return typedResp, err - } - return nil, err - } - } - return caller(ctx, in) -} - -func (c *abilityProtobufClient) callAllowed(ctx context.Context, in *AllowRequest) (*AllowReply, error) { - out := new(AllowReply) - ctx, err := doProtobufRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) - if err != nil { - twerr, ok := err.(twirp.Error) - if !ok { - twerr = twirp.InternalErrorWith(err) - } - callClientError(ctx, c.opts.Hooks, twerr) - return nil, err - } - - callClientResponseReceived(ctx, c.opts.Hooks) - - return out, nil -} - -// =================== -// Ability JSON Client -// =================== - -type abilityJSONClient struct { - client HTTPClient - urls [1]string - interceptor twirp.Interceptor - opts twirp.ClientOptions -} - -// NewAbilityJSONClient creates a JSON client that implements the Ability interface. -// It communicates using JSON and can be configured with a custom HTTPClient. -func NewAbilityJSONClient(baseURL string, client HTTPClient, opts ...twirp.ClientOption) Ability { - if c, ok := client.(*http.Client); ok { - client = withoutRedirects(c) - } - - clientOpts := twirp.ClientOptions{} - for _, o := range opts { - o(&clientOpts) - } - - // Using ReadOpt allows backwards and forwards compatibility with new options in the future - literalURLs := false - _ = clientOpts.ReadOpt("literalURLs", &literalURLs) - var pathPrefix string - if ok := clientOpts.ReadOpt("pathPrefix", &pathPrefix); !ok { - pathPrefix = "/twirp" // default prefix - } - - // Build method URLs: []/./ - serviceURL := sanitizeBaseURL(baseURL) - serviceURL += baseServicePath(pathPrefix, "authx.rpc", "Ability") - urls := [1]string{ - serviceURL + "Allowed", - } - - return &abilityJSONClient{ - client: client, - urls: urls, - interceptor: twirp.ChainInterceptors(clientOpts.Interceptors...), - opts: clientOpts, - } -} - -func (c *abilityJSONClient) Allowed(ctx context.Context, in *AllowRequest) (*AllowReply, error) { - ctx = ctxsetters.WithPackageName(ctx, "authx.rpc") - ctx = ctxsetters.WithServiceName(ctx, "Ability") - ctx = ctxsetters.WithMethodName(ctx, "Allowed") - caller := c.callAllowed - if c.interceptor != nil { - caller = func(ctx context.Context, req *AllowRequest) (*AllowReply, error) { - resp, err := c.interceptor( - func(ctx context.Context, req interface{}) (interface{}, error) { - typedReq, ok := req.(*AllowRequest) - if !ok { - return nil, twirp.InternalError("failed type assertion req.(*AllowRequest) when calling interceptor") - } - return c.callAllowed(ctx, typedReq) - }, - )(ctx, req) - if resp != nil { - typedResp, ok := resp.(*AllowReply) - if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*AllowReply) when calling interceptor") - } - return typedResp, err - } - return nil, err - } - } - return caller(ctx, in) -} - -func (c *abilityJSONClient) callAllowed(ctx context.Context, in *AllowRequest) (*AllowReply, error) { - out := new(AllowReply) - ctx, err := doJSONRequest(ctx, c.client, c.opts.Hooks, c.urls[0], in, out) - if err != nil { - twerr, ok := err.(twirp.Error) - if !ok { - twerr = twirp.InternalErrorWith(err) - } - callClientError(ctx, c.opts.Hooks, twerr) - return nil, err - } - - callClientResponseReceived(ctx, c.opts.Hooks) - - return out, nil -} - -// ====================== -// Ability Server Handler -// ====================== - -type abilityServer struct { - Ability - interceptor twirp.Interceptor - hooks *twirp.ServerHooks - pathPrefix string // prefix for routing - jsonSkipDefaults bool // do not include unpopulated fields (default values) in the response - jsonCamelCase bool // JSON fields are serialized as lowerCamelCase rather than keeping the original proto names -} - -// NewAbilityServer builds a TwirpServer that can be used as an http.Handler to handle -// HTTP requests that are routed to the right method in the provided svc implementation. -// The opts are twirp.ServerOption modifiers, for example twirp.WithServerHooks(hooks). -func NewAbilityServer(svc Ability, opts ...interface{}) TwirpServer { - serverOpts := newServerOpts(opts) - - // Using ReadOpt allows backwards and forwards compatibility with new options in the future - jsonSkipDefaults := false - _ = serverOpts.ReadOpt("jsonSkipDefaults", &jsonSkipDefaults) - jsonCamelCase := false - _ = serverOpts.ReadOpt("jsonCamelCase", &jsonCamelCase) - var pathPrefix string - if ok := serverOpts.ReadOpt("pathPrefix", &pathPrefix); !ok { - pathPrefix = "/twirp" // default prefix - } - - return &abilityServer{ - Ability: svc, - hooks: serverOpts.Hooks, - interceptor: twirp.ChainInterceptors(serverOpts.Interceptors...), - pathPrefix: pathPrefix, - jsonSkipDefaults: jsonSkipDefaults, - jsonCamelCase: jsonCamelCase, - } -} - -// writeError writes an HTTP response with a valid Twirp error format, and triggers hooks. -// If err is not a twirp.Error, it will get wrapped with twirp.InternalErrorWith(err) -func (s *abilityServer) writeError(ctx context.Context, resp http.ResponseWriter, err error) { - writeError(ctx, resp, err, s.hooks) -} - -// handleRequestBodyError is used to handle error when the twirp server cannot read request -func (s *abilityServer) handleRequestBodyError(ctx context.Context, resp http.ResponseWriter, msg string, err error) { - if context.Canceled == ctx.Err() { - s.writeError(ctx, resp, twirp.NewError(twirp.Canceled, "failed to read request: context canceled")) - return - } - if context.DeadlineExceeded == ctx.Err() { - s.writeError(ctx, resp, twirp.NewError(twirp.DeadlineExceeded, "failed to read request: deadline exceeded")) - return - } - s.writeError(ctx, resp, twirp.WrapError(malformedRequestError(msg), err)) -} - -// AbilityPathPrefix is a convenience constant that may identify URL paths. -// Should be used with caution, it only matches routes generated by Twirp Go clients, -// with the default "/twirp" prefix and default CamelCase service and method names. -// More info: https://twitchtv.github.io/twirp/docs/routing.html -const AbilityPathPrefix = "/twirp/authx.rpc.Ability/" - -func (s *abilityServer) ServeHTTP(resp http.ResponseWriter, req *http.Request) { - ctx := req.Context() - ctx = ctxsetters.WithPackageName(ctx, "authx.rpc") - ctx = ctxsetters.WithServiceName(ctx, "Ability") - ctx = ctxsetters.WithResponseWriter(ctx, resp) - - var err error - ctx, err = callRequestReceived(ctx, s.hooks) - if err != nil { - s.writeError(ctx, resp, err) - return - } - - if req.Method != "POST" { - msg := fmt.Sprintf("unsupported method %q (only POST is allowed)", req.Method) - s.writeError(ctx, resp, badRouteError(msg, req.Method, req.URL.Path)) - return - } - - // Verify path format: []/./ - prefix, pkgService, method := parseTwirpPath(req.URL.Path) - if pkgService != "authx.rpc.Ability" { - msg := fmt.Sprintf("no handler for path %q", req.URL.Path) - s.writeError(ctx, resp, badRouteError(msg, req.Method, req.URL.Path)) - return - } - if prefix != s.pathPrefix { - msg := fmt.Sprintf("invalid path prefix %q, expected %q, on path %q", prefix, s.pathPrefix, req.URL.Path) - s.writeError(ctx, resp, badRouteError(msg, req.Method, req.URL.Path)) - return - } - - switch method { - case "Allowed": - s.serveAllowed(ctx, resp, req) - return - default: - msg := fmt.Sprintf("no handler for path %q", req.URL.Path) - s.writeError(ctx, resp, badRouteError(msg, req.Method, req.URL.Path)) - return - } -} - -func (s *abilityServer) serveAllowed(ctx context.Context, resp http.ResponseWriter, req *http.Request) { - header := req.Header.Get("Content-Type") - i := strings.Index(header, ";") - if i == -1 { - i = len(header) - } - switch strings.TrimSpace(strings.ToLower(header[:i])) { - case "application/json": - s.serveAllowedJSON(ctx, resp, req) - case "application/protobuf": - s.serveAllowedProtobuf(ctx, resp, req) - default: - msg := fmt.Sprintf("unexpected Content-Type: %q", req.Header.Get("Content-Type")) - twerr := badRouteError(msg, req.Method, req.URL.Path) - s.writeError(ctx, resp, twerr) - } -} - -func (s *abilityServer) serveAllowedJSON(ctx context.Context, resp http.ResponseWriter, req *http.Request) { - var err error - ctx = ctxsetters.WithMethodName(ctx, "Allowed") - ctx, err = callRequestRouted(ctx, s.hooks) - if err != nil { - s.writeError(ctx, resp, err) - return - } - - d := json.NewDecoder(req.Body) - rawReqBody := json.RawMessage{} - if err := d.Decode(&rawReqBody); err != nil { - s.handleRequestBodyError(ctx, resp, "the json request could not be decoded", err) - return - } - reqContent := new(AllowRequest) - unmarshaler := protojson.UnmarshalOptions{DiscardUnknown: true} - if err = unmarshaler.Unmarshal(rawReqBody, reqContent); err != nil { - s.handleRequestBodyError(ctx, resp, "the json request could not be decoded", err) - return - } - - handler := s.Ability.Allowed - if s.interceptor != nil { - handler = func(ctx context.Context, req *AllowRequest) (*AllowReply, error) { - resp, err := s.interceptor( - func(ctx context.Context, req interface{}) (interface{}, error) { - typedReq, ok := req.(*AllowRequest) - if !ok { - return nil, twirp.InternalError("failed type assertion req.(*AllowRequest) when calling interceptor") - } - return s.Ability.Allowed(ctx, typedReq) - }, - )(ctx, req) - if resp != nil { - typedResp, ok := resp.(*AllowReply) - if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*AllowReply) when calling interceptor") - } - return typedResp, err - } - return nil, err - } - } - - // Call service method - var respContent *AllowReply - func() { - defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = handler(ctx, reqContent) - }() - - if err != nil { - s.writeError(ctx, resp, err) - return - } - if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *AllowReply and nil error while calling Allowed. nil responses are not supported")) - return - } - - ctx = callResponsePrepared(ctx, s.hooks) - - marshaler := &protojson.MarshalOptions{UseProtoNames: !s.jsonCamelCase, EmitUnpopulated: !s.jsonSkipDefaults} - respBytes, err := marshaler.Marshal(respContent) - if err != nil { - s.writeError(ctx, resp, wrapInternal(err, "failed to marshal json response")) - return - } - - ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) - resp.Header().Set("Content-Type", "application/json") - resp.Header().Set("Content-Length", strconv.Itoa(len(respBytes))) - resp.WriteHeader(http.StatusOK) - - if n, err := resp.Write(respBytes); err != nil { - msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) - twerr := twirp.NewError(twirp.Unknown, msg) - ctx = callError(ctx, s.hooks, twerr) - } - callResponseSent(ctx, s.hooks) -} - -func (s *abilityServer) serveAllowedProtobuf(ctx context.Context, resp http.ResponseWriter, req *http.Request) { - var err error - ctx = ctxsetters.WithMethodName(ctx, "Allowed") - ctx, err = callRequestRouted(ctx, s.hooks) - if err != nil { - s.writeError(ctx, resp, err) - return - } - - buf, err := io.ReadAll(req.Body) - if err != nil { - s.handleRequestBodyError(ctx, resp, "failed to read request body", err) - return - } - reqContent := new(AllowRequest) - if err = proto.Unmarshal(buf, reqContent); err != nil { - s.writeError(ctx, resp, malformedRequestError("the protobuf request could not be decoded")) - return - } - - handler := s.Ability.Allowed - if s.interceptor != nil { - handler = func(ctx context.Context, req *AllowRequest) (*AllowReply, error) { - resp, err := s.interceptor( - func(ctx context.Context, req interface{}) (interface{}, error) { - typedReq, ok := req.(*AllowRequest) - if !ok { - return nil, twirp.InternalError("failed type assertion req.(*AllowRequest) when calling interceptor") - } - return s.Ability.Allowed(ctx, typedReq) - }, - )(ctx, req) - if resp != nil { - typedResp, ok := resp.(*AllowReply) - if !ok { - return nil, twirp.InternalError("failed type assertion resp.(*AllowReply) when calling interceptor") - } - return typedResp, err - } - return nil, err - } - } - - // Call service method - var respContent *AllowReply - func() { - defer ensurePanicResponses(ctx, resp, s.hooks) - respContent, err = handler(ctx, reqContent) - }() - - if err != nil { - s.writeError(ctx, resp, err) - return - } - if respContent == nil { - s.writeError(ctx, resp, twirp.InternalError("received a nil *AllowReply and nil error while calling Allowed. nil responses are not supported")) - return - } - - ctx = callResponsePrepared(ctx, s.hooks) - - respBytes, err := proto.Marshal(respContent) - if err != nil { - s.writeError(ctx, resp, wrapInternal(err, "failed to marshal proto response")) - return - } - - ctx = ctxsetters.WithStatusCode(ctx, http.StatusOK) - resp.Header().Set("Content-Type", "application/protobuf") - resp.Header().Set("Content-Length", strconv.Itoa(len(respBytes))) - resp.WriteHeader(http.StatusOK) - if n, err := resp.Write(respBytes); err != nil { - msg := fmt.Sprintf("failed to write response, %d of %d bytes written: %s", n, len(respBytes), err.Error()) - twerr := twirp.NewError(twirp.Unknown, msg) - ctx = callError(ctx, s.hooks, twerr) - } - callResponseSent(ctx, s.hooks) -} - -func (s *abilityServer) ServiceDescriptor() ([]byte, int) { - return twirpFileDescriptor0, 0 -} - -func (s *abilityServer) ProtocGenTwirpVersion() string { - return "v8.1.3" -} - -// PathPrefix returns the base service path, in the form: "//./" -// that is everything in a Twirp route except for the . This can be used for routing, -// for example to identify the requests that are targeted to this service in a mux. -func (s *abilityServer) PathPrefix() string { - return baseServicePath(s.pathPrefix, "authx.rpc", "Ability") -} - -// ===== -// Utils -// ===== - -// HTTPClient is the interface used by generated clients to send HTTP requests. -// It is fulfilled by *(net/http).Client, which is sufficient for most users. -// Users can provide their own implementation for special retry policies. -// -// HTTPClient implementations should not follow redirects. Redirects are -// automatically disabled if *(net/http).Client is passed to client -// constructors. See the withoutRedirects function in this file for more -// details. -type HTTPClient interface { - Do(req *http.Request) (*http.Response, error) -} - -// TwirpServer is the interface generated server structs will support: they're -// HTTP handlers with additional methods for accessing metadata about the -// service. Those accessors are a low-level API for building reflection tools. -// Most people can think of TwirpServers as just http.Handlers. -type TwirpServer interface { - http.Handler - - // ServiceDescriptor returns gzipped bytes describing the .proto file that - // this service was generated from. Once unzipped, the bytes can be - // unmarshalled as a - // google.golang.org/protobuf/types/descriptorpb.FileDescriptorProto. - // - // The returned integer is the index of this particular service within that - // FileDescriptorProto's 'Service' slice of ServiceDescriptorProtos. This is a - // low-level field, expected to be used for reflection. - ServiceDescriptor() ([]byte, int) - - // ProtocGenTwirpVersion is the semantic version string of the version of - // twirp used to generate this file. - ProtocGenTwirpVersion() string - - // PathPrefix returns the HTTP URL path prefix for all methods handled by this - // service. This can be used with an HTTP mux to route Twirp requests. - // The path prefix is in the form: "//./" - // that is, everything in a Twirp route except for the at the end. - PathPrefix() string -} - -func newServerOpts(opts []interface{}) *twirp.ServerOptions { - serverOpts := &twirp.ServerOptions{} - for _, opt := range opts { - switch o := opt.(type) { - case twirp.ServerOption: - o(serverOpts) - case *twirp.ServerHooks: // backwards compatibility, allow to specify hooks as an argument - twirp.WithServerHooks(o)(serverOpts) - case nil: // backwards compatibility, allow nil value for the argument - continue - default: - panic(fmt.Sprintf("Invalid option type %T, please use a twirp.ServerOption", o)) - } - } - return serverOpts -} - -// WriteError writes an HTTP response with a valid Twirp error format (code, msg, meta). -// Useful outside of the Twirp server (e.g. http middleware), but does not trigger hooks. -// If err is not a twirp.Error, it will get wrapped with twirp.InternalErrorWith(err) -func WriteError(resp http.ResponseWriter, err error) { - writeError(context.Background(), resp, err, nil) -} - -// writeError writes Twirp errors in the response and triggers hooks. -func writeError(ctx context.Context, resp http.ResponseWriter, err error, hooks *twirp.ServerHooks) { - // Convert to a twirp.Error. Non-twirp errors are converted to internal errors. - var twerr twirp.Error - if !errors.As(err, &twerr) { - twerr = twirp.InternalErrorWith(err) - } - - statusCode := twirp.ServerHTTPStatusFromErrorCode(twerr.Code()) - ctx = ctxsetters.WithStatusCode(ctx, statusCode) - ctx = callError(ctx, hooks, twerr) - - respBody := marshalErrorToJSON(twerr) - - resp.Header().Set("Content-Type", "application/json") // Error responses are always JSON - resp.Header().Set("Content-Length", strconv.Itoa(len(respBody))) - resp.WriteHeader(statusCode) // set HTTP status code and send response - - _, writeErr := resp.Write(respBody) - if writeErr != nil { - // We have three options here. We could log the error, call the Error - // hook, or just silently ignore the error. - // - // Logging is unacceptable because we don't have a user-controlled - // logger; writing out to stderr without permission is too rude. - // - // Calling the Error hook would confuse users: it would mean the Error - // hook got called twice for one request, which is likely to lead to - // duplicated log messages and metrics, no matter how well we document - // the behavior. - // - // Silently ignoring the error is our least-bad option. It's highly - // likely that the connection is broken and the original 'err' says - // so anyway. - _ = writeErr - } - - callResponseSent(ctx, hooks) -} - -// sanitizeBaseURL parses the the baseURL, and adds the "http" scheme if needed. -// If the URL is unparsable, the baseURL is returned unchanged. -func sanitizeBaseURL(baseURL string) string { - u, err := url.Parse(baseURL) - if err != nil { - return baseURL // invalid URL will fail later when making requests - } - if u.Scheme == "" { - u.Scheme = "http" - } - return u.String() -} - -// baseServicePath composes the path prefix for the service (without ). -// e.g.: baseServicePath("/twirp", "my.pkg", "MyService") -// -// returns => "/twirp/my.pkg.MyService/" -// -// e.g.: baseServicePath("", "", "MyService") -// -// returns => "/MyService/" -func baseServicePath(prefix, pkg, service string) string { - fullServiceName := service - if pkg != "" { - fullServiceName = pkg + "." + service - } - return path.Join("/", prefix, fullServiceName) + "/" -} - -// parseTwirpPath extracts path components form a valid Twirp route. -// Expected format: "[]/./" -// e.g.: prefix, pkgService, method := parseTwirpPath("/twirp/pkg.Svc/MakeHat") -func parseTwirpPath(path string) (string, string, string) { - parts := strings.Split(path, "/") - if len(parts) < 2 { - return "", "", "" - } - method := parts[len(parts)-1] - pkgService := parts[len(parts)-2] - prefix := strings.Join(parts[0:len(parts)-2], "/") - return prefix, pkgService, method -} - -// getCustomHTTPReqHeaders retrieves a copy of any headers that are set in -// a context through the twirp.WithHTTPRequestHeaders function. -// If there are no headers set, or if they have the wrong type, nil is returned. -func getCustomHTTPReqHeaders(ctx context.Context) http.Header { - header, ok := twirp.HTTPRequestHeaders(ctx) - if !ok || header == nil { - return nil - } - copied := make(http.Header) - for k, vv := range header { - if vv == nil { - copied[k] = nil - continue - } - copied[k] = make([]string, len(vv)) - copy(copied[k], vv) - } - return copied -} - -// newRequest makes an http.Request from a client, adding common headers. -func newRequest(ctx context.Context, url string, reqBody io.Reader, contentType string) (*http.Request, error) { - req, err := http.NewRequest("POST", url, reqBody) - if err != nil { - return nil, err - } - req = req.WithContext(ctx) - if customHeader := getCustomHTTPReqHeaders(ctx); customHeader != nil { - req.Header = customHeader - } - req.Header.Set("Accept", contentType) - req.Header.Set("Content-Type", contentType) - req.Header.Set("Twirp-Version", "v8.1.3") - return req, nil -} - -// JSON serialization for errors -type twerrJSON struct { - Code string `json:"code"` - Msg string `json:"msg"` - Meta map[string]string `json:"meta,omitempty"` -} - -// marshalErrorToJSON returns JSON from a twirp.Error, that can be used as HTTP error response body. -// If serialization fails, it will use a descriptive Internal error instead. -func marshalErrorToJSON(twerr twirp.Error) []byte { - // make sure that msg is not too large - msg := twerr.Msg() - if len(msg) > 1e6 { - msg = msg[:1e6] - } - - tj := twerrJSON{ - Code: string(twerr.Code()), - Msg: msg, - Meta: twerr.MetaMap(), - } - - buf, err := json.Marshal(&tj) - if err != nil { - buf = []byte("{\"type\": \"" + twirp.Internal + "\", \"msg\": \"There was an error but it could not be serialized into JSON\"}") // fallback - } - - return buf -} - -// errorFromResponse builds a twirp.Error from a non-200 HTTP response. -// If the response has a valid serialized Twirp error, then it's returned. -// If not, the response status code is used to generate a similar twirp -// error. See twirpErrorFromIntermediary for more info on intermediary errors. -func errorFromResponse(resp *http.Response) twirp.Error { - statusCode := resp.StatusCode - statusText := http.StatusText(statusCode) - - if isHTTPRedirect(statusCode) { - // Unexpected redirect: it must be an error from an intermediary. - // Twirp clients don't follow redirects automatically, Twirp only handles - // POST requests, redirects should only happen on GET and HEAD requests. - location := resp.Header.Get("Location") - msg := fmt.Sprintf("unexpected HTTP status code %d %q received, Location=%q", statusCode, statusText, location) - return twirpErrorFromIntermediary(statusCode, msg, location) - } - - respBodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return wrapInternal(err, "failed to read server error response body") - } - - var tj twerrJSON - dec := json.NewDecoder(bytes.NewReader(respBodyBytes)) - dec.DisallowUnknownFields() - if err := dec.Decode(&tj); err != nil || tj.Code == "" { - // Invalid JSON response; it must be an error from an intermediary. - msg := fmt.Sprintf("Error from intermediary with HTTP status code %d %q", statusCode, statusText) - return twirpErrorFromIntermediary(statusCode, msg, string(respBodyBytes)) - } - - errorCode := twirp.ErrorCode(tj.Code) - if !twirp.IsValidErrorCode(errorCode) { - msg := "invalid type returned from server error response: " + tj.Code - return twirp.InternalError(msg).WithMeta("body", string(respBodyBytes)) - } - - twerr := twirp.NewError(errorCode, tj.Msg) - for k, v := range tj.Meta { - twerr = twerr.WithMeta(k, v) - } - return twerr -} - -// twirpErrorFromIntermediary maps HTTP errors from non-twirp sources to twirp errors. -// The mapping is similar to gRPC: https://github.com/grpc/grpc/blob/master/doc/http-grpc-status-mapping.md. -// Returned twirp Errors have some additional metadata for inspection. -func twirpErrorFromIntermediary(status int, msg string, bodyOrLocation string) twirp.Error { - var code twirp.ErrorCode - if isHTTPRedirect(status) { // 3xx - code = twirp.Internal - } else { - switch status { - case 400: // Bad Request - code = twirp.Internal - case 401: // Unauthorized - code = twirp.Unauthenticated - case 403: // Forbidden - code = twirp.PermissionDenied - case 404: // Not Found - code = twirp.BadRoute - case 429: // Too Many Requests - code = twirp.ResourceExhausted - case 502, 503, 504: // Bad Gateway, Service Unavailable, Gateway Timeout - code = twirp.Unavailable - default: // All other codes - code = twirp.Unknown - } - } - - twerr := twirp.NewError(code, msg) - twerr = twerr.WithMeta("http_error_from_intermediary", "true") // to easily know if this error was from intermediary - twerr = twerr.WithMeta("status_code", strconv.Itoa(status)) - if isHTTPRedirect(status) { - twerr = twerr.WithMeta("location", bodyOrLocation) - } else { - twerr = twerr.WithMeta("body", bodyOrLocation) - } - return twerr -} - -func isHTTPRedirect(status int) bool { - return status >= 300 && status <= 399 -} - -// wrapInternal wraps an error with a prefix as an Internal error. -// The original error cause is accessible by github.com/pkg/errors.Cause. -func wrapInternal(err error, prefix string) twirp.Error { - return twirp.InternalErrorWith(&wrappedError{prefix: prefix, cause: err}) -} - -type wrappedError struct { - prefix string - cause error -} - -func (e *wrappedError) Error() string { return e.prefix + ": " + e.cause.Error() } -func (e *wrappedError) Unwrap() error { return e.cause } // for go1.13 + errors.Is/As -func (e *wrappedError) Cause() error { return e.cause } // for github.com/pkg/errors - -// ensurePanicResponses makes sure that rpc methods causing a panic still result in a Twirp Internal -// error response (status 500), and error hooks are properly called with the panic wrapped as an error. -// The panic is re-raised so it can be handled normally with middleware. -func ensurePanicResponses(ctx context.Context, resp http.ResponseWriter, hooks *twirp.ServerHooks) { - if r := recover(); r != nil { - // Wrap the panic as an error so it can be passed to error hooks. - // The original error is accessible from error hooks, but not visible in the response. - err := errFromPanic(r) - twerr := &internalWithCause{msg: "Internal service panic", cause: err} - // Actually write the error - writeError(ctx, resp, twerr, hooks) - // If possible, flush the error to the wire. - f, ok := resp.(http.Flusher) - if ok { - f.Flush() - } - - panic(r) - } -} - -// errFromPanic returns the typed error if the recovered panic is an error, otherwise formats as error. -func errFromPanic(p interface{}) error { - if err, ok := p.(error); ok { - return err - } - return fmt.Errorf("panic: %v", p) -} - -// internalWithCause is a Twirp Internal error wrapping an original error cause, -// but the original error message is not exposed on Msg(). The original error -// can be checked with go1.13+ errors.Is/As, and also by (github.com/pkg/errors).Unwrap -type internalWithCause struct { - msg string - cause error -} - -func (e *internalWithCause) Unwrap() error { return e.cause } // for go1.13 + errors.Is/As -func (e *internalWithCause) Cause() error { return e.cause } // for github.com/pkg/errors -func (e *internalWithCause) Error() string { return e.msg + ": " + e.cause.Error() } -func (e *internalWithCause) Code() twirp.ErrorCode { return twirp.Internal } -func (e *internalWithCause) Msg() string { return e.msg } -func (e *internalWithCause) Meta(key string) string { return "" } -func (e *internalWithCause) MetaMap() map[string]string { return nil } -func (e *internalWithCause) WithMeta(key string, val string) twirp.Error { return e } - -// malformedRequestError is used when the twirp server cannot unmarshal a request -func malformedRequestError(msg string) twirp.Error { - return twirp.NewError(twirp.Malformed, msg) -} - -// badRouteError is used when the twirp server cannot route a request -func badRouteError(msg string, method, url string) twirp.Error { - err := twirp.NewError(twirp.BadRoute, msg) - err = err.WithMeta("twirp_invalid_route", method+" "+url) - return err -} - -// withoutRedirects makes sure that the POST request can not be redirected. -// The standard library will, by default, redirect requests (including POSTs) if it gets a 302 or -// 303 response, and also 301s in go1.8. It redirects by making a second request, changing the -// method to GET and removing the body. This produces very confusing error messages, so instead we -// set a redirect policy that always errors. This stops Go from executing the redirect. -// -// We have to be a little careful in case the user-provided http.Client has its own CheckRedirect -// policy - if so, we'll run through that policy first. -// -// Because this requires modifying the http.Client, we make a new copy of the client and return it. -func withoutRedirects(in *http.Client) *http.Client { - copy := *in - copy.CheckRedirect = func(req *http.Request, via []*http.Request) error { - if in.CheckRedirect != nil { - // Run the input's redirect if it exists, in case it has side effects, but ignore any error it - // returns, since we want to use ErrUseLastResponse. - err := in.CheckRedirect(req, via) - _ = err // Silly, but this makes sure generated code passes errcheck -blank, which some people use. - } - return http.ErrUseLastResponse - } - return © -} - -// doProtobufRequest makes a Protobuf request to the remote Twirp service. -func doProtobufRequest(ctx context.Context, client HTTPClient, hooks *twirp.ClientHooks, url string, in, out proto.Message) (_ context.Context, err error) { - reqBodyBytes, err := proto.Marshal(in) - if err != nil { - return ctx, wrapInternal(err, "failed to marshal proto request") - } - reqBody := bytes.NewBuffer(reqBodyBytes) - if err = ctx.Err(); err != nil { - return ctx, wrapInternal(err, "aborted because context was done") - } - - req, err := newRequest(ctx, url, reqBody, "application/protobuf") - if err != nil { - return ctx, wrapInternal(err, "could not build request") - } - ctx, err = callClientRequestPrepared(ctx, hooks, req) - if err != nil { - return ctx, err - } - - req = req.WithContext(ctx) - resp, err := client.Do(req) - if err != nil { - return ctx, wrapInternal(err, "failed to do request") - } - defer func() { _ = resp.Body.Close() }() - - if err = ctx.Err(); err != nil { - return ctx, wrapInternal(err, "aborted because context was done") - } - - if resp.StatusCode != 200 { - return ctx, errorFromResponse(resp) - } - - respBodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return ctx, wrapInternal(err, "failed to read response body") - } - if err = ctx.Err(); err != nil { - return ctx, wrapInternal(err, "aborted because context was done") - } - - if err = proto.Unmarshal(respBodyBytes, out); err != nil { - return ctx, wrapInternal(err, "failed to unmarshal proto response") - } - return ctx, nil -} - -// doJSONRequest makes a JSON request to the remote Twirp service. -func doJSONRequest(ctx context.Context, client HTTPClient, hooks *twirp.ClientHooks, url string, in, out proto.Message) (_ context.Context, err error) { - marshaler := &protojson.MarshalOptions{UseProtoNames: true} - reqBytes, err := marshaler.Marshal(in) - if err != nil { - return ctx, wrapInternal(err, "failed to marshal json request") - } - if err = ctx.Err(); err != nil { - return ctx, wrapInternal(err, "aborted because context was done") - } - - req, err := newRequest(ctx, url, bytes.NewReader(reqBytes), "application/json") - if err != nil { - return ctx, wrapInternal(err, "could not build request") - } - ctx, err = callClientRequestPrepared(ctx, hooks, req) - if err != nil { - return ctx, err - } - - req = req.WithContext(ctx) - resp, err := client.Do(req) - if err != nil { - return ctx, wrapInternal(err, "failed to do request") - } - - defer func() { - cerr := resp.Body.Close() - if err == nil && cerr != nil { - err = wrapInternal(cerr, "failed to close response body") - } - }() - - if err = ctx.Err(); err != nil { - return ctx, wrapInternal(err, "aborted because context was done") - } - - if resp.StatusCode != 200 { - return ctx, errorFromResponse(resp) - } - - d := json.NewDecoder(resp.Body) - rawRespBody := json.RawMessage{} - if err := d.Decode(&rawRespBody); err != nil { - return ctx, wrapInternal(err, "failed to unmarshal json response") - } - unmarshaler := protojson.UnmarshalOptions{DiscardUnknown: true} - if err = unmarshaler.Unmarshal(rawRespBody, out); err != nil { - return ctx, wrapInternal(err, "failed to unmarshal json response") - } - if err = ctx.Err(); err != nil { - return ctx, wrapInternal(err, "aborted because context was done") - } - return ctx, nil -} - -// Call twirp.ServerHooks.RequestReceived if the hook is available -func callRequestReceived(ctx context.Context, h *twirp.ServerHooks) (context.Context, error) { - if h == nil || h.RequestReceived == nil { - return ctx, nil - } - return h.RequestReceived(ctx) -} - -// Call twirp.ServerHooks.RequestRouted if the hook is available -func callRequestRouted(ctx context.Context, h *twirp.ServerHooks) (context.Context, error) { - if h == nil || h.RequestRouted == nil { - return ctx, nil - } - return h.RequestRouted(ctx) -} - -// Call twirp.ServerHooks.ResponsePrepared if the hook is available -func callResponsePrepared(ctx context.Context, h *twirp.ServerHooks) context.Context { - if h == nil || h.ResponsePrepared == nil { - return ctx - } - return h.ResponsePrepared(ctx) -} - -// Call twirp.ServerHooks.ResponseSent if the hook is available -func callResponseSent(ctx context.Context, h *twirp.ServerHooks) { - if h == nil || h.ResponseSent == nil { - return - } - h.ResponseSent(ctx) -} - -// Call twirp.ServerHooks.Error if the hook is available -func callError(ctx context.Context, h *twirp.ServerHooks, err twirp.Error) context.Context { - if h == nil || h.Error == nil { - return ctx - } - return h.Error(ctx, err) -} - -func callClientResponseReceived(ctx context.Context, h *twirp.ClientHooks) { - if h == nil || h.ResponseReceived == nil { - return - } - h.ResponseReceived(ctx) -} - -func callClientRequestPrepared(ctx context.Context, h *twirp.ClientHooks, req *http.Request) (context.Context, error) { - if h == nil || h.RequestPrepared == nil { - return ctx, nil - } - return h.RequestPrepared(ctx, req) -} - -func callClientError(ctx context.Context, h *twirp.ClientHooks, err twirp.Error) { - if h == nil || h.Error == nil { - return - } - h.Error(ctx, err) -} - -var twirpFileDescriptor0 = []byte{ - // 216 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x64, 0x90, 0xbd, 0x4e, 0xc3, 0x30, - 0x14, 0x46, 0x29, 0x48, 0x4d, 0x7b, 0x05, 0x8b, 0x25, 0xc0, 0xea, 0xc0, 0x4f, 0xc4, 0xc0, 0x64, - 0x4b, 0x30, 0x32, 0x95, 0x81, 0x07, 0xc8, 0xc8, 0x66, 0xbb, 0x57, 0xad, 0x89, 0x83, 0x2f, 0xfe, - 0x11, 0xcd, 0xdb, 0x23, 0x39, 0x21, 0x8a, 0xd4, 0xf1, 0xf8, 0xc8, 0xfa, 0x8e, 0x2e, 0x5c, 0x29, - 0x6d, 0x9d, 0x4d, 0xbd, 0xa0, 0xe0, 0x93, 0x67, 0x6b, 0x95, 0xd3, 0xe1, 0x28, 0x02, 0x99, 0x7a, - 0x07, 0x97, 0x5b, 0xe7, 0xfc, 0x6f, 0x83, 0x3f, 0x19, 0x63, 0x62, 0x1c, 0xaa, 0x98, 0xf5, 0x17, - 0x9a, 0xc4, 0x17, 0x0f, 0x8b, 0xe7, 0x75, 0xf3, 0x8f, 0xec, 0x0e, 0x80, 0x30, 0x74, 0x36, 0x46, - 0xeb, 0xbf, 0xf9, 0x79, 0x91, 0xb3, 0x17, 0xb6, 0x81, 0x55, 0xc0, 0xe8, 0x73, 0x30, 0xc8, 0x2f, - 0x8a, 0x9d, 0xb8, 0x7e, 0x02, 0x18, 0x57, 0xc8, 0xf5, 0xec, 0x06, 0x96, 0x01, 0x63, 0x76, 0xc3, - 0xc4, 0xaa, 0x19, 0xe9, 0xe5, 0x03, 0xaa, 0xed, 0xd0, 0xc9, 0xde, 0xa0, 0x2a, 0x1f, 0x70, 0xc7, - 0x6e, 0xc5, 0x54, 0x2b, 0xe6, 0xa9, 0x9b, 0xeb, 0x53, 0x41, 0xae, 0xaf, 0xcf, 0xde, 0x1f, 0x3f, - 0xef, 0xf7, 0x36, 0x39, 0xa5, 0x85, 0xf1, 0x9d, 0xec, 0x7c, 0x7b, 0x50, 0x47, 0x19, 0xc9, 0xb6, - 0x28, 0xa9, 0xdd, 0xcb, 0x40, 0x46, 0x2f, 0xcb, 0x21, 0x5e, 0xff, 0x02, 0x00, 0x00, 0xff, 0xff, - 0xe2, 0x96, 0x42, 0xb1, 0x19, 0x01, 0x00, 0x00, -} diff --git a/pkg/rpc/server.go b/pkg/rpc/server.go index 08246b5b..a37df9fc 100644 --- a/pkg/rpc/server.go +++ b/pkg/rpc/server.go @@ -1,11 +1,21 @@ package rpc import ( - grpc "google.golang.org/grpc" + fmt "fmt" + http "net/http" ) -func New(options ...grpc.ServerOption) *grpc.Server { - server := grpc.NewServer(options...) - RegisterAbilityServer(server, NewAbilityService()) - return server +func New() http.Handler { + mux := http.NewServeMux() + for _, handler := range handlers() { + fmt.Printf("Registering : %v\n", handler.PathPrefix()) + mux.Handle(handler.PathPrefix(), handler) + } + return mux +} + +func handlers() []TwirpServer { + return []TwirpServer{ + NewAbilityServer(NewAbilityService()), + } } diff --git a/pkg/rpc/server_test.go b/pkg/rpc/server_test.go index da60f86a..fd6e6237 100644 --- a/pkg/rpc/server_test.go +++ b/pkg/rpc/server_test.go @@ -1,35 +1,19 @@ package rpc import ( - "net" + http "net/http" + "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - grpc "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" ) func TestServer(t *testing.T) { - listener, err := net.Listen("tcp", "localhost:0") - require.NoError(t, err) - defer listener.Close() + srv := httptest.NewServer(New()) + defer srv.Close() - server := New() - defer server.Stop() - - go func() { - require.NoError(t, server.Serve(listener)) - }() - - connection, err := grpc.NewClient( - listener.Addr().String(), - grpc.WithTransportCredentials(insecure.NewCredentials()), - ) - require.NoError(t, err) - - defer connection.Close() - client := NewAbilityClient(connection) + client := NewAbilityProtobufClient(srv.URL, &http.Client{}) t.Run("forbids", func(t *testing.T) { reply, err := client.Allowed(t.Context(), &AllowRequest{ diff --git a/pkg/srv/srv.go b/pkg/srv/srv.go deleted file mode 100644 index e7189406..00000000 --- a/pkg/srv/srv.go +++ /dev/null @@ -1,26 +0,0 @@ -package srv - -import ( - "log" - "net/http" - "time" - - "gitlab.com/mokhax/spike/pkg/cfg" -) - -func New(c *cfg.Config) *http.Server { - return &http.Server{ - Addr: c.BindAddress, - Handler: c.Mux, - TLSConfig: c.TLS, - ReadHeaderTimeout: 10 * time.Second, - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - IdleTimeout: 30 * time.Second, - ErrorLog: log.Default(), - } -} - -func Run(c *cfg.Config) error { - return c.Run(New(c)) -} diff --git a/pkg/test/test.go b/pkg/test/test.go deleted file mode 100644 index 9963323a..00000000 --- a/pkg/test/test.go +++ /dev/null @@ -1,49 +0,0 @@ -package test - -import ( - "context" - "io" - "net/http" - "net/http/httptest" -) - -type RequestOption func(*http.Request) *http.Request - -func Request(method, target string, options ...RequestOption) *http.Request { - request := httptest.NewRequest(method, target, nil) - for _, option := range options { - request = option(request) - } - return request -} - -func RequestResponse(method, target string, options ...RequestOption) (*http.Request, *httptest.ResponseRecorder) { - return Request(method, target, options...), httptest.NewRecorder() -} - -func WithRequestHeader(key, value string) RequestOption { - return func(r *http.Request) *http.Request { - r.Header.Set(key, value) - return r - } -} - -func WithRequestBody(body io.ReadCloser) RequestOption { - return func(r *http.Request) *http.Request { - r.Body = body - return r - } -} - -func WithContext(ctx context.Context) RequestOption { - return func(r *http.Request) *http.Request { - return r.WithContext(ctx) - } -} - -func WithCookie(cookie *http.Cookie) RequestOption { - return func(r *http.Request) *http.Request { - r.AddCookie(cookie) - return r - } -} -- cgit v1.2.3