diff options
| -rw-r--r-- | pkg/middleware/jwt.go | 44 |
1 files changed, 36 insertions, 8 deletions
diff --git a/pkg/middleware/jwt.go b/pkg/middleware/jwt.go index 947c3ec..db04fc6 100644 --- a/pkg/middleware/jwt.go +++ b/pkg/middleware/jwt.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "errors" "fmt" "log" "net/http" @@ -33,6 +34,30 @@ func (c CustomClaims) HasScope(expectedScope string) bool { return false } +// type TokenExtractor func(r *http.Request) (string, error) +func Extractor(r *http.Request) (string, error) { + authHeader := r.Header.Get("Authorization") + fmt.Printf("%v %v\nAuthorization: %v\n", r.Method, r.URL.Path, authHeader) + if authHeader == "" { + return "", nil + } + + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", errors.New("Authorization header format must be Bearer {token}") + } + + rawToken := authHeaderParts[1] + sections := strings.Split(authHeaderParts[1], ".") + if len(sections) != 3 { + fmt.Printf("sections: %v\n", len(sections)) + return "", errors.New("Token is not a JWT") + } + + return rawToken, nil + return "", nil +} + func EnsureValidToken() func(next http.Handler) http.Handler { issuerURL, err := url.Parse("https://" + os.Getenv("AUTH0_DOMAIN") + "/") if err != nil { @@ -58,13 +83,6 @@ func EnsureValidToken() func(next http.Handler) http.Handler { errorHandler := func(w http.ResponseWriter, r *http.Request, err error) { fmt.Printf("Error: %v\n", err) - if r.Method == "OPTIONS" { - w.Header().Set("Access-Control-Allow-Credentials", "true") - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Headers", "Authorization") - w.WriteHeader(http.StatusOK) - return - } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusUnauthorized) @@ -74,9 +92,19 @@ func EnsureValidToken() func(next http.Handler) http.Handler { middleware := jwtmiddleware.New( jwtValidator.ValidateToken, jwtmiddleware.WithErrorHandler(errorHandler), + jwtmiddleware.WithTokenExtractor(Extractor), ) return func(next http.Handler) http.Handler { - return middleware.CheckJWT(next) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "OPTIONS" { + w.Header().Set("Access-Control-Allow-Credentials", "true") + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Headers", "Authorization") + w.WriteHeader(http.StatusOK) + } else { + middleware.CheckJWT(next).ServeHTTP(w, r) + } + }) } } |
