summaryrefslogtreecommitdiff
path: root/vendor/github.com/oauth2-proxy/mockoidc/handlers.go
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-05-11 21:12:57 -0600
committermo khan <mo@mokhan.ca>2025-05-11 21:12:57 -0600
commit60440f90dca28e99a31dd328c5f6d5dc0f9b6a2e (patch)
tree2f54adf55086516f162f0a55a5347e6b25f7f176 /vendor/github.com/oauth2-proxy/mockoidc/handlers.go
parent05ca9b8d3a9c7203a3a3b590beaa400900bd9007 (diff)
chore: vendor go dependencies
Diffstat (limited to 'vendor/github.com/oauth2-proxy/mockoidc/handlers.go')
-rw-r--r--vendor/github.com/oauth2-proxy/mockoidc/handlers.go527
1 files changed, 527 insertions, 0 deletions
diff --git a/vendor/github.com/oauth2-proxy/mockoidc/handlers.go b/vendor/github.com/oauth2-proxy/mockoidc/handlers.go
new file mode 100644
index 0000000..d1405f1
--- /dev/null
+++ b/vendor/github.com/oauth2-proxy/mockoidc/handlers.go
@@ -0,0 +1,527 @@
+package mockoidc
+
+import (
+ "crypto/subtle"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/url"
+ "strings"
+ "time"
+
+ "github.com/golang-jwt/jwt/v5"
+)
+
+const (
+ IssuerBase = "/oidc"
+ AuthorizationEndpoint = "/oidc/authorize"
+ TokenEndpoint = "/oidc/token"
+ UserinfoEndpoint = "/oidc/userinfo"
+ JWKSEndpoint = "/oidc/.well-known/jwks.json"
+ DiscoveryEndpoint = "/oidc/.well-known/openid-configuration"
+
+ InvalidRequest = "invalid_request"
+ InvalidClient = "invalid_client"
+ InvalidGrant = "invalid_grant"
+ UnsupportedGrantType = "unsupported_grant_type"
+ InvalidScope = "invalid_scope"
+ //UnauthorizedClient = "unauthorized_client"
+ InternalServerError = "internal_server_error"
+
+ applicationJSON = "application/json"
+ openidScope = "openid"
+)
+
+var (
+ GrantTypesSupported = []string{
+ "authorization_code",
+ "refresh_token",
+ }
+ ResponseTypesSupported = []string{
+ "code",
+ }
+ SubjectTypesSupported = []string{
+ "public",
+ }
+ IDTokenSigningAlgValuesSupported = []string{
+ "RS256",
+ }
+ ScopesSupported = []string{
+ "openid",
+ "email",
+ "groups",
+ "profile",
+ }
+ TokenEndpointAuthMethodsSupported = []string{
+ "client_secret_basic",
+ "client_secret_post",
+ }
+ ClaimsSupported = []string{
+ "sub",
+ "email",
+ "email_verified",
+ "preferred_username",
+ "phone_number",
+ "address",
+ "groups",
+ "iss",
+ "aud",
+ }
+)
+
+// Authorize implements the `authorization_endpoint` in the OIDC flow.
+// It is the initial request that "authenticates" a user in the OAuth2
+// flow and redirects the client to the application `redirect_uri`.
+func (m *MockOIDC) Authorize(rw http.ResponseWriter, req *http.Request) {
+ err := req.ParseForm()
+ if err != nil {
+ internalServerError(rw, err.Error())
+ return
+ }
+
+ valid := assertPresence(
+ []string{"scope", "state", "client_id", "response_type", "redirect_uri"}, rw, req)
+ if !valid {
+ return
+ }
+
+ if !validateScope(rw, req) {
+ return
+ }
+ validClient := assertEqual("client_id", m.ClientID,
+ InvalidClient, "Invalid client id", rw, req)
+ if !validClient {
+ return
+ }
+ validType := assertEqual("response_type", "code",
+ UnsupportedGrantType, "Invalid response type", rw, req)
+ if !validType {
+ return
+ }
+ if !validateCodeChallengeMethodSupported(rw, req.Form.Get("code_challenge_method"), m.CodeChallengeMethodsSupported) {
+ return
+ }
+
+ session, err := m.SessionStore.NewSession(
+ req.Form.Get("scope"),
+ req.Form.Get("nonce"),
+ m.UserQueue.Pop(),
+ req.Form.Get("code_challenge"),
+ req.Form.Get("code_challenge_method"),
+ )
+ if err != nil {
+ internalServerError(rw, err.Error())
+ return
+ }
+
+ redirectURI, err := url.Parse(req.Form.Get("redirect_uri"))
+ if err != nil {
+ internalServerError(rw, err.Error())
+ return
+ }
+ params, _ := url.ParseQuery(redirectURI.RawQuery)
+ params.Set("code", session.SessionID)
+ params.Set("state", req.Form.Get("state"))
+ redirectURI.RawQuery = params.Encode()
+
+ http.Redirect(rw, req, redirectURI.String(), http.StatusFound)
+}
+
+type tokenResponse struct {
+ AccessToken string `json:"access_token,omitempty"`
+ RefreshToken string `json:"refresh_token,omitempty"`
+ IDToken string `json:"id_token,omitempty"`
+ TokenType string `json:"token_type"`
+ ExpiresIn time.Duration `json:"expires_in"`
+}
+
+// Token implements the `token_endpoint` in OIDC and responds to requests
+// from the application servers that contain the client ID & Secret along
+// with the code from the `authorization_endpoint`. It returns the various
+// OAuth tokens to the application server for the User authenticated by the
+// during the `authorization_endpoint` request (persisted across requests via
+// the `code`).
+// Reference: https://www.oauth.com/oauth2-servers/access-tokens/access-token-response/
+func (m *MockOIDC) Token(rw http.ResponseWriter, req *http.Request) {
+ err := req.ParseForm()
+ if err != nil {
+ internalServerError(rw, err.Error())
+ return
+ }
+
+ if !m.validateTokenParams(rw, req) {
+ return
+ }
+
+ var (
+ session *Session
+ valid bool
+ )
+ grantType := req.Form.Get("grant_type")
+ switch grantType {
+ case "authorization_code":
+ if session, valid = m.validateCodeGrant(rw, req); !valid {
+ return
+ }
+
+ if !m.validateCodeChallenge(rw, req, session) {
+ return
+ }
+ case "refresh_token":
+ if session, valid = m.validateRefreshGrant(rw, req); !valid {
+ return
+ }
+ default:
+ errorResponse(rw, InvalidRequest,
+ fmt.Sprintf("Invalid grant type: %s", grantType), http.StatusBadRequest)
+ return
+ }
+
+ tr := &tokenResponse{
+ RefreshToken: req.Form.Get("refresh_token"),
+ TokenType: "bearer",
+ ExpiresIn: m.AccessTTL,
+ }
+ err = m.setTokens(tr, session, grantType)
+ if err != nil {
+ internalServerError(rw, err.Error())
+ return
+ }
+
+ resp, err := json.Marshal(tr)
+ if err != nil {
+ internalServerError(rw, err.Error())
+ return
+ }
+ noCache(rw)
+ jsonResponse(rw, resp)
+}
+
+func (m *MockOIDC) validateTokenParams(rw http.ResponseWriter, req *http.Request) bool {
+ if !assertPresence([]string{"client_id", "client_secret", "grant_type"}, rw, req) {
+ return false
+ }
+
+ equal := assertEqual("client_id", m.ClientID,
+ InvalidClient, "Invalid client id", rw, req)
+ if !equal {
+ return false
+ }
+ equal = assertEqual("client_secret", m.ClientSecret,
+ InvalidClient, "Invalid client secret", rw, req)
+ if !equal {
+ return false
+ }
+
+ return true
+}
+
+func (m *MockOIDC) validateCodeGrant(rw http.ResponseWriter, req *http.Request) (*Session, bool) {
+ if !assertPresence([]string{"code"}, rw, req) {
+ return nil, false
+ }
+ equal := assertEqual("grant_type", "authorization_code",
+ UnsupportedGrantType, "Invalid grant type", rw, req)
+ if !equal {
+ return nil, false
+ }
+
+ code := req.Form.Get("code")
+ session, err := m.SessionStore.GetSessionByID(code)
+ if err != nil || session.Granted {
+ errorResponse(rw, InvalidGrant, fmt.Sprintf("Invalid code: %s", code),
+ http.StatusUnauthorized)
+ return nil, false
+ }
+ session.Granted = true
+
+ return session, true
+}
+
+func (m *MockOIDC) validateCodeChallenge(rw http.ResponseWriter, req *http.Request, session *Session) bool {
+ if session.CodeChallenge == "" || session.CodeChallengeMethod == "" {
+ return true
+ }
+
+ codeVerifier := req.Form.Get("code_verifier")
+ if codeVerifier == "" {
+ errorResponse(rw, InvalidGrant, "Invalid code verifier. Expected code but client sent none.", http.StatusUnauthorized)
+ return false
+ }
+
+ challenge, err := GenerateCodeChallenge(session.CodeChallengeMethod, codeVerifier)
+ if err != nil {
+ errorResponse(rw, InvalidRequest, fmt.Sprintf("Invalid code verifier. %v", err.Error()), http.StatusUnauthorized)
+ return false
+ }
+
+ if challenge != session.CodeChallenge {
+ errorResponse(rw, InvalidGrant, "Invalid code verifier. Code challenge did not match hashed code verifier.", http.StatusUnauthorized)
+ return false
+ }
+
+ return true
+}
+
+func (m *MockOIDC) validateRefreshGrant(rw http.ResponseWriter, req *http.Request) (*Session, bool) {
+ if !assertPresence([]string{"refresh_token"}, rw, req) {
+ return nil, false
+ }
+
+ equal := assertEqual("grant_type", "refresh_token",
+ UnsupportedGrantType, "Invalid grant type", rw, req)
+ if !equal {
+ return nil, false
+ }
+
+ refreshToken := req.Form.Get("refresh_token")
+ token, authorized := m.authorizeToken(refreshToken, rw)
+ if !authorized {
+ return nil, false
+ }
+
+ session, err := m.SessionStore.GetSessionByToken(token)
+ if err != nil {
+ errorResponse(rw, InvalidGrant, "Invalid refresh token",
+ http.StatusUnauthorized)
+ return nil, false
+ }
+ return session, true
+}
+
+func (m *MockOIDC) setTokens(tr *tokenResponse, s *Session, grantType string) error {
+ var err error
+ tr.AccessToken, err = s.AccessToken(m.Config(), m.Keypair, m.Now())
+ if err != nil {
+ return err
+ }
+ if len(s.Scopes) > 0 && s.Scopes[0] == openidScope {
+ tr.IDToken, err = s.IDToken(m.Config(), m.Keypair, m.Now())
+ if err != nil {
+ return err
+ }
+ }
+ if grantType != "refresh_token" {
+ tr.RefreshToken, err = s.RefreshToken(m.Config(), m.Keypair, m.Now())
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+// Userinfo returns the User details for the User associated with the passed
+// Access Token. Data is scoped down to the session's access scope set in the
+// initial `authorization_endpoint` call.
+func (m *MockOIDC) Userinfo(rw http.ResponseWriter, req *http.Request) {
+ token, authorized := m.authorizeBearer(rw, req)
+ if !authorized {
+ return
+ }
+
+ session, err := m.SessionStore.GetSessionByToken(token)
+ if err != nil {
+ internalServerError(rw, err.Error())
+ return
+ }
+
+ resp, err := session.User.Userinfo(session.Scopes)
+ if err != nil {
+ internalServerError(rw, err.Error())
+ return
+ }
+ jsonResponse(rw, resp)
+}
+
+type discoveryResponse struct {
+ Issuer string `json:"issuer"`
+ AuthorizationEndpoint string `json:"authorization_endpoint"`
+ TokenEndpoint string `json:"token_endpoint"`
+ JWKSUri string `json:"jwks_uri"`
+ UserinfoEndpoint string `json:"userinfo_endpoint"`
+
+ GrantTypesSupported []string `json:"grant_types_supported"`
+ ResponseTypesSupported []string `json:"response_types_supported"`
+ SubjectTypesSupported []string `json:"subject_types_supported"`
+ IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"`
+ ScopesSupported []string `json:"scopes_supported"`
+ TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"`
+ ClaimsSupported []string `json:"claims_supported"`
+ CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"`
+}
+
+// Discovery renders the OIDC discovery document and partial RFC-8414 authorization
+// server metadata hosted at `/.well-known/openid-configuration`.
+func (m *MockOIDC) Discovery(rw http.ResponseWriter, _ *http.Request) {
+ discovery := &discoveryResponse{
+ Issuer: m.Issuer(),
+ AuthorizationEndpoint: m.AuthorizationEndpoint(),
+ TokenEndpoint: m.TokenEndpoint(),
+ JWKSUri: m.JWKSEndpoint(),
+ UserinfoEndpoint: m.UserinfoEndpoint(),
+
+ GrantTypesSupported: GrantTypesSupported,
+ ResponseTypesSupported: ResponseTypesSupported,
+ SubjectTypesSupported: SubjectTypesSupported,
+ IDTokenSigningAlgValuesSupported: IDTokenSigningAlgValuesSupported,
+ ScopesSupported: ScopesSupported,
+ TokenEndpointAuthMethodsSupported: TokenEndpointAuthMethodsSupported,
+ ClaimsSupported: ClaimsSupported,
+ CodeChallengeMethodsSupported: m.CodeChallengeMethodsSupported,
+ }
+
+ resp, err := json.Marshal(discovery)
+ if err != nil {
+ internalServerError(rw, err.Error())
+ return
+ }
+ jsonResponse(rw, resp)
+}
+
+// JWKS returns the public key in JWKS format to verify in tokens
+// signed with our Keypair.PrivateKey.
+func (m *MockOIDC) JWKS(rw http.ResponseWriter, _ *http.Request) {
+ jwks, err := m.Keypair.JWKS()
+ if err != nil {
+ internalServerError(rw, err.Error())
+ return
+ }
+
+ jsonResponse(rw, jwks)
+}
+
+func (m *MockOIDC) authorizeBearer(rw http.ResponseWriter, req *http.Request) (*jwt.Token, bool) {
+ header := req.Header.Get("Authorization")
+ parts := strings.SplitN(header, " ", 2)
+ if len(parts) < 2 || parts[0] != "Bearer" {
+ errorResponse(rw, InvalidRequest, "Invalid authorization header",
+ http.StatusUnauthorized)
+ return nil, false
+ }
+
+ return m.authorizeToken(parts[1], rw)
+}
+
+func (m *MockOIDC) authorizeToken(t string, rw http.ResponseWriter) (*jwt.Token, bool) {
+ token, err := m.Keypair.VerifyJWT(t, m.Now)
+ if err != nil {
+ errorResponse(rw, InvalidRequest, fmt.Sprintf("Invalid token: %v", err), http.StatusUnauthorized)
+ return nil, false
+ }
+
+ claims, ok := token.Claims.(jwt.MapClaims)
+ if !ok {
+ internalServerError(rw, "Unable to extract token claims")
+ return nil, false
+ }
+ exp, ok := claims["exp"].(float64)
+ if !ok {
+ internalServerError(rw, "Unable to extract token expiration")
+ return nil, false
+ }
+ if m.Now().Unix() > int64(exp) {
+ errorResponse(rw, InvalidRequest, "The token is expired", http.StatusUnauthorized)
+ return nil, false
+ }
+ return token, true
+}
+
+func assertPresence(params []string, rw http.ResponseWriter, req *http.Request) bool {
+ for _, param := range params {
+ if req.Form.Get(param) != "" {
+ continue
+ }
+ errorResponse(
+ rw,
+ InvalidRequest,
+ fmt.Sprintf("The request is missing the required parameter: %s", param),
+ http.StatusBadRequest,
+ )
+ return false
+ }
+ return true
+}
+
+func assertEqual(param, value, errorType, errorMsg string, rw http.ResponseWriter, req *http.Request) bool {
+ formValue := req.Form.Get(param)
+ if subtle.ConstantTimeCompare([]byte(value), []byte(formValue)) == 0 {
+ errorResponse(rw, errorType, fmt.Sprintf("%s: %s", errorMsg, formValue),
+ http.StatusUnauthorized)
+ return false
+ }
+ return true
+}
+
+func validateScope(rw http.ResponseWriter, req *http.Request) bool {
+ allowed := make(map[string]struct{})
+ for _, scope := range ScopesSupported {
+ allowed[scope] = struct{}{}
+ }
+
+ scopes := strings.Split(req.Form.Get("scope"), " ")
+ for _, scope := range scopes {
+ if _, ok := allowed[scope]; !ok {
+ errorResponse(rw, InvalidScope, fmt.Sprintf("Unsupported scope: %s", scope),
+ http.StatusBadRequest)
+ return false
+ }
+ }
+ return true
+}
+
+func validateCodeChallengeMethodSupported(rw http.ResponseWriter, method string, supportedMethods []string) bool {
+ if method != "" && !contains(method, supportedMethods) {
+ errorResponse(rw, InvalidRequest, "Invalid code challenge method", http.StatusBadRequest)
+ return false
+ }
+ return true
+}
+
+func errorResponse(rw http.ResponseWriter, error, description string, statusCode int) {
+ errJSON := map[string]string{
+ "error": error,
+ "error_description": description,
+ }
+ resp, err := json.Marshal(errJSON)
+ if err != nil {
+ http.Error(rw, error, http.StatusInternalServerError)
+ }
+
+ noCache(rw)
+ rw.Header().Set("Content-Type", applicationJSON)
+ rw.WriteHeader(statusCode)
+
+ _, err = rw.Write(resp)
+ if err != nil {
+ panic(err)
+ }
+}
+
+func internalServerError(rw http.ResponseWriter, errorMsg string) {
+ errorResponse(rw, InternalServerError, errorMsg, http.StatusInternalServerError)
+}
+
+func jsonResponse(rw http.ResponseWriter, data []byte) {
+ noCache(rw)
+ rw.Header().Set("Content-Type", applicationJSON)
+ rw.WriteHeader(http.StatusOK)
+
+ _, err := rw.Write(data)
+ if err != nil {
+ panic(err)
+ }
+}
+
+func noCache(rw http.ResponseWriter) {
+ rw.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate, max-age=0")
+ rw.Header().Set("Pragma", "no-cache")
+}
+
+func contains(value string, list []string) bool {
+ for _, element := range list {
+ if element == value {
+ return true
+ }
+ }
+ return false
+}