summaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/authz/authz.go23
-rw-r--r--pkg/prxy/prxy.go31
-rw-r--r--pkg/prxy/prxy_test.go49
-rw-r--r--pkg/test/test.go49
4 files changed, 136 insertions, 16 deletions
diff --git a/pkg/authz/authz.go b/pkg/authz/authz.go
new file mode 100644
index 00000000..5a93a29c
--- /dev/null
+++ b/pkg/authz/authz.go
@@ -0,0 +1,23 @@
+package authz
+
+import "net/http"
+
+type Authorizer interface {
+ Authorize(*http.Request) bool
+}
+
+type AuthorizerFunc func(*http.Request) bool
+
+func (f AuthorizerFunc) Authorize(r *http.Request) bool {
+ return f(r)
+}
+
+func HTTP(authorizer Authorizer, h http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if authorizer.Authorize(r) {
+ h.ServeHTTP(w, r)
+ } else {
+ w.WriteHeader(http.StatusForbidden)
+ }
+ })
+}
diff --git a/pkg/prxy/prxy.go b/pkg/prxy/prxy.go
index 54aad00c..0e6e8c31 100644
--- a/pkg/prxy/prxy.go
+++ b/pkg/prxy/prxy.go
@@ -3,33 +3,32 @@ package prxy
import (
"fmt"
"log"
+ "net"
"net/http"
"net/http/httputil"
- "strings"
+ "net/url"
- "github.com/casbin/casbin/v2"
"github.com/xlgmokha/x/pkg/x"
)
func New(routes map[string]string) http.Handler {
- authz := x.Must(casbin.NewEnforcer("model.conf", "policy.csv"))
+ mapped := map[string]*url.URL{}
+ for source, destination := range routes {
+ mapped[source] = x.Must(url.Parse(destination))
+ }
return &httputil.ReverseProxy{
Director: func(r *http.Request) {
- segments := strings.SplitN(r.Host, ":", 2)
- host := segments[0]
- destinationHost := routes[host]
-
- log.Printf("%v (from: %v to: %v)\n", r.URL, host, destinationHost)
-
- subject := "71cbc18e-bd41-4229-9ad2-749546a2a4a7" // TODO:: unpack sub claim in JWT
- if x.Must(authz.Enforce(subject, host, r.Method, r.URL.Path)) {
- r.URL.Scheme = "http" // TODO:: use TLS
- r.Host = destinationHost
- r.URL.Host = destinationHost
- } else {
- log.Println("UNAUTHORIZED") // TODO:: Return forbidden, unauthorized or not found status code
+ host, _, err := net.SplitHostPort(r.Host)
+ if err != nil {
+ fmt.Printf("%v\n", err)
+ return
}
+
+ destination := mapped[host]
+ r.URL.Scheme = destination.Scheme
+ r.Host = destination.Host
+ r.URL.Host = destination.Host
},
Transport: http.DefaultTransport,
FlushInterval: -1,
diff --git a/pkg/prxy/prxy_test.go b/pkg/prxy/prxy_test.go
new file mode 100644
index 00000000..6f37974e
--- /dev/null
+++ b/pkg/prxy/prxy_test.go
@@ -0,0 +1,49 @@
+package prxy
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+ "github.com/xlgmokha/x/pkg/x"
+ "gitlab.com/mokhax/spike/pkg/test"
+)
+
+func TestProxy(t *testing.T) {
+ t.Run("http://idp.test", func(t *testing.T) {
+ var lastIdPRequest *http.Request
+ var lastUiRequest *http.Request
+
+ idp := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ lastIdPRequest = r
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer idp.Close()
+
+ ui := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ lastUiRequest = r
+ w.WriteHeader(http.StatusTeapot)
+ }))
+ defer ui.Close()
+
+ subject := New(map[string]string{
+ "idp.test": idp.URL,
+ "ui.test": ui.URL,
+ })
+
+ r, w := test.RequestResponse("GET", "http://idp.test:8080/saml/new")
+
+ subject.ServeHTTP(w, r)
+
+ url := x.Must(url.Parse(idp.URL))
+
+ assert.Nil(t, lastUiRequest)
+ assert.Equal(t, http.StatusOK, w.Code)
+
+ require.NotNil(t, lastIdPRequest)
+ assert.Equal(t, url.Host, lastIdPRequest.Host)
+ })
+}
diff --git a/pkg/test/test.go b/pkg/test/test.go
new file mode 100644
index 00000000..9963323a
--- /dev/null
+++ b/pkg/test/test.go
@@ -0,0 +1,49 @@
+package test
+
+import (
+ "context"
+ "io"
+ "net/http"
+ "net/http/httptest"
+)
+
+type RequestOption func(*http.Request) *http.Request
+
+func Request(method, target string, options ...RequestOption) *http.Request {
+ request := httptest.NewRequest(method, target, nil)
+ for _, option := range options {
+ request = option(request)
+ }
+ return request
+}
+
+func RequestResponse(method, target string, options ...RequestOption) (*http.Request, *httptest.ResponseRecorder) {
+ return Request(method, target, options...), httptest.NewRecorder()
+}
+
+func WithRequestHeader(key, value string) RequestOption {
+ return func(r *http.Request) *http.Request {
+ r.Header.Set(key, value)
+ return r
+ }
+}
+
+func WithRequestBody(body io.ReadCloser) RequestOption {
+ return func(r *http.Request) *http.Request {
+ r.Body = body
+ return r
+ }
+}
+
+func WithContext(ctx context.Context) RequestOption {
+ return func(r *http.Request) *http.Request {
+ return r.WithContext(ctx)
+ }
+}
+
+func WithCookie(cookie *http.Cookie) RequestOption {
+ return func(r *http.Request) *http.Request {
+ r.AddCookie(cookie)
+ return r
+ }
+}