diff options
| author | mo khan <mo@mokhan.ca> | 2025-03-13 16:43:47 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-03-13 16:43:47 -0600 |
| commit | c9f394fe7fa0a5a6504b5b80ae7019cffdf4bb14 (patch) | |
| tree | da1ef1c59264221c2c483ddd76401ee19cd1015c | |
| parent | b55a6617971fa50bb064480f78343e6c0bc59dbe (diff) | |
refactor: extract authz interface to test out different PaC libraries
| -rwxr-xr-x | bin/api | 98 | ||||
| -rw-r--r-- | cmd/gtwy/main.go | 36 | ||||
| -rw-r--r-- | magefile.go | 13 | ||||
| -rw-r--r-- | pkg/authz/authz.go | 23 | ||||
| -rw-r--r-- | pkg/prxy/prxy.go | 31 | ||||
| -rw-r--r-- | pkg/prxy/prxy_test.go | 49 | ||||
| -rw-r--r-- | pkg/test/test.go | 49 | ||||
| -rw-r--r-- | test/e2e_test.go | 48 |
8 files changed, 257 insertions, 90 deletions
@@ -27,20 +27,10 @@ $scheme = ENV.fetch("SCHEME", "http") $port = ENV.fetch("PORT", 8284).to_i $host = ENV.fetch("HOST", "localhost:#{$port}") -class Organization - def initialize(attributes = {}) - @attributes = attributes - end - - def id - @attributes[:id] - end -end - -class Project +class Entity class << self def all - @projects ||= [] + @items ||= [] end def create!(attributes) @@ -54,47 +44,32 @@ class Project @attributes = attributes end - def to_h - @attributes + def id + self[:id] end -end -class API - attr_reader :rpc + def [](attribute) + @attributes.fetch(attribute) + end - def initialize - @rpc = ::Authx::Rpc::AbilityClient.new("http://idp.example.com:8080/twirp") + def to_h + @attributes end +end - def call(env) - request = Rack::Request.new(env) - path = env['PATH_INFO'] - case env['REQUEST_METHOD'] - when 'GET' - case path - when '/projects.json' - return json_ok(Project.all.map(&:to_h)) - else - return json_not_found - end - when 'POST' - case path - when "/projects" - if authorized?(request, :create_project) - return json_created(Project.create!(JSON.parse(request.body.read, symbolize_names: true))) - else - return json_unauthorized(:create_project) - end - else - return json_not_found - end +class Organization < Entity + class << self + def default + @default ||= create!(id: SecureRandom.uuid) end - json_not_found end +end - private +class Project < Entity +end - def authorized?(request, permission, resource = Organization.new(id: 1)) +module HTTPHelpers + def authorized?(request, permission, resource) authorization = Rack::Auth::AbstractRequest.new(request.env) return false unless authorization.provided? @@ -136,6 +111,41 @@ class API end end +class API + include HTTPHelpers + + attr_reader :rpc + + def initialize + @rpc = ::Authx::Rpc::AbilityClient.new("http://idp.example.com:8080/twirp") + end + + def call(env) + request = Rack::Request.new(env) + case request.request_method + when Rack::GET + case request.path + when "/organizations", "/organizations.json" + return json_ok(Organization.all.map(&:to_h)) + when "/projects", "/projects.json" + return json_ok(Project.all.map(&:to_h)) + end + when Rack::POST + case request.path + when "/projects", "/projects.json" + if authorized?(request, :create_project, Organization.default) + return json_created(Project.create!(JSON.parse(request.body.read, symbolize_names: true))) + else + return json_unauthorized(:create_project) + end + end + end + json_not_found + end + + private +end + if __FILE__ == $0 app = Rack::Builder.new do use Rack::CommonLogger diff --git a/cmd/gtwy/main.go b/cmd/gtwy/main.go index 1e9d3a39..0da2ea88 100644 --- a/cmd/gtwy/main.go +++ b/cmd/gtwy/main.go @@ -1,23 +1,49 @@ package main import ( + "fmt" "log" + "net" "net/http" + "github.com/casbin/casbin/v2" "github.com/xlgmokha/x/pkg/env" + "github.com/xlgmokha/x/pkg/x" + "gitlab.com/mokhax/spike/pkg/authz" "gitlab.com/mokhax/spike/pkg/cfg" "gitlab.com/mokhax/spike/pkg/prxy" "gitlab.com/mokhax/spike/pkg/srv" ) +func WithCasbin() authz.Authorizer { + enforcer := x.Must(casbin.NewEnforcer("model.conf", "policy.csv")) + + return authz.AuthorizerFunc(func(r *http.Request) bool { + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + return false + } + + subject := "71cbc18e-bd41-4229-9ad2-749546a2a4a7" // TODO:: unpack sub claim in JWT + ok, err := enforcer.Enforce(subject, host, r.Method, r.URL.Path) + if err != nil { + fmt.Printf("%v\n", err) + return false + } + + fmt.Printf("%v: %v %v %v\n", ok, r.Method, host, r.URL.Path) + return ok + }) +} + func WithRoutes() cfg.Option { return func(c *cfg.Config) { mux := http.NewServeMux() - mux.Handle("/", prxy.New(map[string]string{ - "idp.example.com": "localhost:8282", - "ui.example.com": "localhost:8283", - "api.example.com": "localhost:8284", - })) + mux.Handle("/", authz.HTTP(WithCasbin(), prxy.New(map[string]string{ + "idp.example.com": "http://localhost:8282", + "ui.example.com": "http://localhost:8283", + "api.example.com": "http://localhost:8284", + }))) cfg.WithMux(mux)(c) } diff --git a/magefile.go b/magefile.go index 92d071da..ec6dac14 100644 --- a/magefile.go +++ b/magefile.go @@ -6,7 +6,6 @@ package main import ( "context" "path/filepath" - "runtime" "github.com/magefile/mage/mg" "github.com/magefile/mage/sh" @@ -56,16 +55,6 @@ func Api() error { return sh.RunWithV(env, "ruby", "./bin/api") } -// Open a web browser to the login page -func Browser() error { - url := "http://localhost:8080/ui/sessions/new" - if runtime.GOOS == "linux" { - return sh.RunV("xdg-open", url) - } else { - return sh.RunV("open", url) - } -} - // Generate gRPC from protocal buffers func Protos() error { outDir := "lib/authx/rpc" @@ -94,5 +83,5 @@ func Test(ctx context.Context) error { mg.CtxDeps(ctx, func() error { return sh.RunV("go", "clean", "-testcache") }) - return sh.RunV("go", "test", "-v", "./...") + return sh.RunV("go", "test", "-shuffle=on", "-v", "./...") } diff --git a/pkg/authz/authz.go b/pkg/authz/authz.go new file mode 100644 index 00000000..5a93a29c --- /dev/null +++ b/pkg/authz/authz.go @@ -0,0 +1,23 @@ +package authz + +import "net/http" + +type Authorizer interface { + Authorize(*http.Request) bool +} + +type AuthorizerFunc func(*http.Request) bool + +func (f AuthorizerFunc) Authorize(r *http.Request) bool { + return f(r) +} + +func HTTP(authorizer Authorizer, h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if authorizer.Authorize(r) { + h.ServeHTTP(w, r) + } else { + w.WriteHeader(http.StatusForbidden) + } + }) +} diff --git a/pkg/prxy/prxy.go b/pkg/prxy/prxy.go index 54aad00c..0e6e8c31 100644 --- a/pkg/prxy/prxy.go +++ b/pkg/prxy/prxy.go @@ -3,33 +3,32 @@ package prxy import ( "fmt" "log" + "net" "net/http" "net/http/httputil" - "strings" + "net/url" - "github.com/casbin/casbin/v2" "github.com/xlgmokha/x/pkg/x" ) func New(routes map[string]string) http.Handler { - authz := x.Must(casbin.NewEnforcer("model.conf", "policy.csv")) + mapped := map[string]*url.URL{} + for source, destination := range routes { + mapped[source] = x.Must(url.Parse(destination)) + } return &httputil.ReverseProxy{ Director: func(r *http.Request) { - segments := strings.SplitN(r.Host, ":", 2) - host := segments[0] - destinationHost := routes[host] - - log.Printf("%v (from: %v to: %v)\n", r.URL, host, destinationHost) - - subject := "71cbc18e-bd41-4229-9ad2-749546a2a4a7" // TODO:: unpack sub claim in JWT - if x.Must(authz.Enforce(subject, host, r.Method, r.URL.Path)) { - r.URL.Scheme = "http" // TODO:: use TLS - r.Host = destinationHost - r.URL.Host = destinationHost - } else { - log.Println("UNAUTHORIZED") // TODO:: Return forbidden, unauthorized or not found status code + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + fmt.Printf("%v\n", err) + return } + + destination := mapped[host] + r.URL.Scheme = destination.Scheme + r.Host = destination.Host + r.URL.Host = destination.Host }, Transport: http.DefaultTransport, FlushInterval: -1, diff --git a/pkg/prxy/prxy_test.go b/pkg/prxy/prxy_test.go new file mode 100644 index 00000000..6f37974e --- /dev/null +++ b/pkg/prxy/prxy_test.go @@ -0,0 +1,49 @@ +package prxy + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/xlgmokha/x/pkg/x" + "gitlab.com/mokhax/spike/pkg/test" +) + +func TestProxy(t *testing.T) { + t.Run("http://idp.test", func(t *testing.T) { + var lastIdPRequest *http.Request + var lastUiRequest *http.Request + + idp := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + lastIdPRequest = r + w.WriteHeader(http.StatusOK) + })) + defer idp.Close() + + ui := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + lastUiRequest = r + w.WriteHeader(http.StatusTeapot) + })) + defer ui.Close() + + subject := New(map[string]string{ + "idp.test": idp.URL, + "ui.test": ui.URL, + }) + + r, w := test.RequestResponse("GET", "http://idp.test:8080/saml/new") + + subject.ServeHTTP(w, r) + + url := x.Must(url.Parse(idp.URL)) + + assert.Nil(t, lastUiRequest) + assert.Equal(t, http.StatusOK, w.Code) + + require.NotNil(t, lastIdPRequest) + assert.Equal(t, url.Host, lastIdPRequest.Host) + }) +} diff --git a/pkg/test/test.go b/pkg/test/test.go new file mode 100644 index 00000000..9963323a --- /dev/null +++ b/pkg/test/test.go @@ -0,0 +1,49 @@ +package test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" +) + +type RequestOption func(*http.Request) *http.Request + +func Request(method, target string, options ...RequestOption) *http.Request { + request := httptest.NewRequest(method, target, nil) + for _, option := range options { + request = option(request) + } + return request +} + +func RequestResponse(method, target string, options ...RequestOption) (*http.Request, *httptest.ResponseRecorder) { + return Request(method, target, options...), httptest.NewRecorder() +} + +func WithRequestHeader(key, value string) RequestOption { + return func(r *http.Request) *http.Request { + r.Header.Set(key, value) + return r + } +} + +func WithRequestBody(body io.ReadCloser) RequestOption { + return func(r *http.Request) *http.Request { + r.Body = body + return r + } +} + +func WithContext(ctx context.Context) RequestOption { + return func(r *http.Request) *http.Request { + return r.WithContext(ctx) + } +} + +func WithCookie(cookie *http.Cookie) RequestOption { + return func(r *http.Request) *http.Request { + r.AddCookie(cookie) + return r + } +} diff --git a/test/e2e_test.go b/test/e2e_test.go index b465d764..7e98b1bf 100644 --- a/test/e2e_test.go +++ b/test/e2e_test.go @@ -80,19 +80,41 @@ func TestAuthx(t *testing.T) { assert.Equal(t, "Bearer", item.TokenType) assert.NotEmpty(t, item.RefreshToken) - response := x.Must(http.Get("http://api.example.com:8080/projects.json")) - assert.Equal(t, http.StatusOK, response.StatusCode) - projects := x.Must(serde.FromJSON[[]map[string]string](response.Body)) - assert.NotNil(t, projects) - - io := bytes.NewBuffer(nil) - assert.NoError(t, serde.ToJSON(io, map[string]string{"name": "example"})) - request := x.Must(http.NewRequestWithContext(t.Context(), "POST", "http://api.example.com:8080/projects", io)) - request.Header.Add("Authorization", "Bearer "+item.AccessToken) - response = x.Must(client.Do(request)) - assert.Equal(t, http.StatusCreated, response.StatusCode) - project := x.Must(serde.FromJSON[map[string]string](response.Body)) - assert.Equal(t, "example", project["name"]) + t.Run("lists all the organizations", func(t *testing.T) { + response := x.Must(http.Get("http://api.example.com:8080/organizations.json")) + assert.Equal(t, http.StatusOK, response.StatusCode) + organizations := x.Must(serde.FromJSON[[]map[string]string](response.Body)) + assert.NotNil(t, organizations) + }) + + t.Run("lists all the projects", func(t *testing.T) { + response := x.Must(http.Get("http://api.example.com:8080/projects.json")) + assert.Equal(t, http.StatusOK, response.StatusCode) + projects := x.Must(serde.FromJSON[[]map[string]string](response.Body)) + assert.NotNil(t, projects) + }) + + t.Run("creates a new project", func(t *testing.T) { + io := bytes.NewBuffer(nil) + assert.NoError(t, serde.ToJSON(io, map[string]string{"name": "example"})) + request := x.Must(http.NewRequestWithContext(t.Context(), "POST", "http://api.example.com:8080/projects", io)) + request.Header.Add("Authorization", "Bearer "+item.AccessToken) + response := x.Must(client.Do(request)) + assert.Equal(t, http.StatusCreated, response.StatusCode) + project := x.Must(serde.FromJSON[map[string]string](response.Body)) + assert.Equal(t, "example", project["name"]) + }) + + t.Run("creates another projects", func(t *testing.T) { + io := bytes.NewBuffer(nil) + assert.NoError(t, serde.ToJSON(io, map[string]string{"name": "example2"})) + request := x.Must(http.NewRequestWithContext(t.Context(), "POST", "http://api.example.com:8080/projects.json", io)) + request.Header.Add("Authorization", "Bearer "+item.AccessToken) + response := x.Must(client.Do(request)) + assert.Equal(t, http.StatusCreated, response.StatusCode) + project := x.Must(serde.FromJSON[map[string]string](response.Body)) + assert.Equal(t, "example2", project["name"]) + }) }) }) |
