diff options
| author | mo khan <mo@mokhan.ca> | 2025-05-11 21:12:57 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-05-11 21:12:57 -0600 |
| commit | 60440f90dca28e99a31dd328c5f6d5dc0f9b6a2e (patch) | |
| tree | 2f54adf55086516f162f0a55a5347e6b25f7f176 /vendor/github.com/oauth2-proxy/mockoidc/handlers.go | |
| parent | 05ca9b8d3a9c7203a3a3b590beaa400900bd9007 (diff) | |
chore: vendor go dependencies
Diffstat (limited to 'vendor/github.com/oauth2-proxy/mockoidc/handlers.go')
| -rw-r--r-- | vendor/github.com/oauth2-proxy/mockoidc/handlers.go | 527 |
1 files changed, 527 insertions, 0 deletions
diff --git a/vendor/github.com/oauth2-proxy/mockoidc/handlers.go b/vendor/github.com/oauth2-proxy/mockoidc/handlers.go new file mode 100644 index 0000000..d1405f1 --- /dev/null +++ b/vendor/github.com/oauth2-proxy/mockoidc/handlers.go @@ -0,0 +1,527 @@ +package mockoidc + +import ( + "crypto/subtle" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +const ( + IssuerBase = "/oidc" + AuthorizationEndpoint = "/oidc/authorize" + TokenEndpoint = "/oidc/token" + UserinfoEndpoint = "/oidc/userinfo" + JWKSEndpoint = "/oidc/.well-known/jwks.json" + DiscoveryEndpoint = "/oidc/.well-known/openid-configuration" + + InvalidRequest = "invalid_request" + InvalidClient = "invalid_client" + InvalidGrant = "invalid_grant" + UnsupportedGrantType = "unsupported_grant_type" + InvalidScope = "invalid_scope" + //UnauthorizedClient = "unauthorized_client" + InternalServerError = "internal_server_error" + + applicationJSON = "application/json" + openidScope = "openid" +) + +var ( + GrantTypesSupported = []string{ + "authorization_code", + "refresh_token", + } + ResponseTypesSupported = []string{ + "code", + } + SubjectTypesSupported = []string{ + "public", + } + IDTokenSigningAlgValuesSupported = []string{ + "RS256", + } + ScopesSupported = []string{ + "openid", + "email", + "groups", + "profile", + } + TokenEndpointAuthMethodsSupported = []string{ + "client_secret_basic", + "client_secret_post", + } + ClaimsSupported = []string{ + "sub", + "email", + "email_verified", + "preferred_username", + "phone_number", + "address", + "groups", + "iss", + "aud", + } +) + +// Authorize implements the `authorization_endpoint` in the OIDC flow. +// It is the initial request that "authenticates" a user in the OAuth2 +// flow and redirects the client to the application `redirect_uri`. +func (m *MockOIDC) Authorize(rw http.ResponseWriter, req *http.Request) { + err := req.ParseForm() + if err != nil { + internalServerError(rw, err.Error()) + return + } + + valid := assertPresence( + []string{"scope", "state", "client_id", "response_type", "redirect_uri"}, rw, req) + if !valid { + return + } + + if !validateScope(rw, req) { + return + } + validClient := assertEqual("client_id", m.ClientID, + InvalidClient, "Invalid client id", rw, req) + if !validClient { + return + } + validType := assertEqual("response_type", "code", + UnsupportedGrantType, "Invalid response type", rw, req) + if !validType { + return + } + if !validateCodeChallengeMethodSupported(rw, req.Form.Get("code_challenge_method"), m.CodeChallengeMethodsSupported) { + return + } + + session, err := m.SessionStore.NewSession( + req.Form.Get("scope"), + req.Form.Get("nonce"), + m.UserQueue.Pop(), + req.Form.Get("code_challenge"), + req.Form.Get("code_challenge_method"), + ) + if err != nil { + internalServerError(rw, err.Error()) + return + } + + redirectURI, err := url.Parse(req.Form.Get("redirect_uri")) + if err != nil { + internalServerError(rw, err.Error()) + return + } + params, _ := url.ParseQuery(redirectURI.RawQuery) + params.Set("code", session.SessionID) + params.Set("state", req.Form.Get("state")) + redirectURI.RawQuery = params.Encode() + + http.Redirect(rw, req, redirectURI.String(), http.StatusFound) +} + +type tokenResponse struct { + AccessToken string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + IDToken string `json:"id_token,omitempty"` + TokenType string `json:"token_type"` + ExpiresIn time.Duration `json:"expires_in"` +} + +// Token implements the `token_endpoint` in OIDC and responds to requests +// from the application servers that contain the client ID & Secret along +// with the code from the `authorization_endpoint`. It returns the various +// OAuth tokens to the application server for the User authenticated by the +// during the `authorization_endpoint` request (persisted across requests via +// the `code`). +// Reference: https://www.oauth.com/oauth2-servers/access-tokens/access-token-response/ +func (m *MockOIDC) Token(rw http.ResponseWriter, req *http.Request) { + err := req.ParseForm() + if err != nil { + internalServerError(rw, err.Error()) + return + } + + if !m.validateTokenParams(rw, req) { + return + } + + var ( + session *Session + valid bool + ) + grantType := req.Form.Get("grant_type") + switch grantType { + case "authorization_code": + if session, valid = m.validateCodeGrant(rw, req); !valid { + return + } + + if !m.validateCodeChallenge(rw, req, session) { + return + } + case "refresh_token": + if session, valid = m.validateRefreshGrant(rw, req); !valid { + return + } + default: + errorResponse(rw, InvalidRequest, + fmt.Sprintf("Invalid grant type: %s", grantType), http.StatusBadRequest) + return + } + + tr := &tokenResponse{ + RefreshToken: req.Form.Get("refresh_token"), + TokenType: "bearer", + ExpiresIn: m.AccessTTL, + } + err = m.setTokens(tr, session, grantType) + if err != nil { + internalServerError(rw, err.Error()) + return + } + + resp, err := json.Marshal(tr) + if err != nil { + internalServerError(rw, err.Error()) + return + } + noCache(rw) + jsonResponse(rw, resp) +} + +func (m *MockOIDC) validateTokenParams(rw http.ResponseWriter, req *http.Request) bool { + if !assertPresence([]string{"client_id", "client_secret", "grant_type"}, rw, req) { + return false + } + + equal := assertEqual("client_id", m.ClientID, + InvalidClient, "Invalid client id", rw, req) + if !equal { + return false + } + equal = assertEqual("client_secret", m.ClientSecret, + InvalidClient, "Invalid client secret", rw, req) + if !equal { + return false + } + + return true +} + +func (m *MockOIDC) validateCodeGrant(rw http.ResponseWriter, req *http.Request) (*Session, bool) { + if !assertPresence([]string{"code"}, rw, req) { + return nil, false + } + equal := assertEqual("grant_type", "authorization_code", + UnsupportedGrantType, "Invalid grant type", rw, req) + if !equal { + return nil, false + } + + code := req.Form.Get("code") + session, err := m.SessionStore.GetSessionByID(code) + if err != nil || session.Granted { + errorResponse(rw, InvalidGrant, fmt.Sprintf("Invalid code: %s", code), + http.StatusUnauthorized) + return nil, false + } + session.Granted = true + + return session, true +} + +func (m *MockOIDC) validateCodeChallenge(rw http.ResponseWriter, req *http.Request, session *Session) bool { + if session.CodeChallenge == "" || session.CodeChallengeMethod == "" { + return true + } + + codeVerifier := req.Form.Get("code_verifier") + if codeVerifier == "" { + errorResponse(rw, InvalidGrant, "Invalid code verifier. Expected code but client sent none.", http.StatusUnauthorized) + return false + } + + challenge, err := GenerateCodeChallenge(session.CodeChallengeMethod, codeVerifier) + if err != nil { + errorResponse(rw, InvalidRequest, fmt.Sprintf("Invalid code verifier. %v", err.Error()), http.StatusUnauthorized) + return false + } + + if challenge != session.CodeChallenge { + errorResponse(rw, InvalidGrant, "Invalid code verifier. Code challenge did not match hashed code verifier.", http.StatusUnauthorized) + return false + } + + return true +} + +func (m *MockOIDC) validateRefreshGrant(rw http.ResponseWriter, req *http.Request) (*Session, bool) { + if !assertPresence([]string{"refresh_token"}, rw, req) { + return nil, false + } + + equal := assertEqual("grant_type", "refresh_token", + UnsupportedGrantType, "Invalid grant type", rw, req) + if !equal { + return nil, false + } + + refreshToken := req.Form.Get("refresh_token") + token, authorized := m.authorizeToken(refreshToken, rw) + if !authorized { + return nil, false + } + + session, err := m.SessionStore.GetSessionByToken(token) + if err != nil { + errorResponse(rw, InvalidGrant, "Invalid refresh token", + http.StatusUnauthorized) + return nil, false + } + return session, true +} + +func (m *MockOIDC) setTokens(tr *tokenResponse, s *Session, grantType string) error { + var err error + tr.AccessToken, err = s.AccessToken(m.Config(), m.Keypair, m.Now()) + if err != nil { + return err + } + if len(s.Scopes) > 0 && s.Scopes[0] == openidScope { + tr.IDToken, err = s.IDToken(m.Config(), m.Keypair, m.Now()) + if err != nil { + return err + } + } + if grantType != "refresh_token" { + tr.RefreshToken, err = s.RefreshToken(m.Config(), m.Keypair, m.Now()) + if err != nil { + return err + } + } + return nil +} + +// Userinfo returns the User details for the User associated with the passed +// Access Token. Data is scoped down to the session's access scope set in the +// initial `authorization_endpoint` call. +func (m *MockOIDC) Userinfo(rw http.ResponseWriter, req *http.Request) { + token, authorized := m.authorizeBearer(rw, req) + if !authorized { + return + } + + session, err := m.SessionStore.GetSessionByToken(token) + if err != nil { + internalServerError(rw, err.Error()) + return + } + + resp, err := session.User.Userinfo(session.Scopes) + if err != nil { + internalServerError(rw, err.Error()) + return + } + jsonResponse(rw, resp) +} + +type discoveryResponse struct { + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + JWKSUri string `json:"jwks_uri"` + UserinfoEndpoint string `json:"userinfo_endpoint"` + + GrantTypesSupported []string `json:"grant_types_supported"` + ResponseTypesSupported []string `json:"response_types_supported"` + SubjectTypesSupported []string `json:"subject_types_supported"` + IDTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` + ScopesSupported []string `json:"scopes_supported"` + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported"` + ClaimsSupported []string `json:"claims_supported"` + CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported"` +} + +// Discovery renders the OIDC discovery document and partial RFC-8414 authorization +// server metadata hosted at `/.well-known/openid-configuration`. +func (m *MockOIDC) Discovery(rw http.ResponseWriter, _ *http.Request) { + discovery := &discoveryResponse{ + Issuer: m.Issuer(), + AuthorizationEndpoint: m.AuthorizationEndpoint(), + TokenEndpoint: m.TokenEndpoint(), + JWKSUri: m.JWKSEndpoint(), + UserinfoEndpoint: m.UserinfoEndpoint(), + + GrantTypesSupported: GrantTypesSupported, + ResponseTypesSupported: ResponseTypesSupported, + SubjectTypesSupported: SubjectTypesSupported, + IDTokenSigningAlgValuesSupported: IDTokenSigningAlgValuesSupported, + ScopesSupported: ScopesSupported, + TokenEndpointAuthMethodsSupported: TokenEndpointAuthMethodsSupported, + ClaimsSupported: ClaimsSupported, + CodeChallengeMethodsSupported: m.CodeChallengeMethodsSupported, + } + + resp, err := json.Marshal(discovery) + if err != nil { + internalServerError(rw, err.Error()) + return + } + jsonResponse(rw, resp) +} + +// JWKS returns the public key in JWKS format to verify in tokens +// signed with our Keypair.PrivateKey. +func (m *MockOIDC) JWKS(rw http.ResponseWriter, _ *http.Request) { + jwks, err := m.Keypair.JWKS() + if err != nil { + internalServerError(rw, err.Error()) + return + } + + jsonResponse(rw, jwks) +} + +func (m *MockOIDC) authorizeBearer(rw http.ResponseWriter, req *http.Request) (*jwt.Token, bool) { + header := req.Header.Get("Authorization") + parts := strings.SplitN(header, " ", 2) + if len(parts) < 2 || parts[0] != "Bearer" { + errorResponse(rw, InvalidRequest, "Invalid authorization header", + http.StatusUnauthorized) + return nil, false + } + + return m.authorizeToken(parts[1], rw) +} + +func (m *MockOIDC) authorizeToken(t string, rw http.ResponseWriter) (*jwt.Token, bool) { + token, err := m.Keypair.VerifyJWT(t, m.Now) + if err != nil { + errorResponse(rw, InvalidRequest, fmt.Sprintf("Invalid token: %v", err), http.StatusUnauthorized) + return nil, false + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + internalServerError(rw, "Unable to extract token claims") + return nil, false + } + exp, ok := claims["exp"].(float64) + if !ok { + internalServerError(rw, "Unable to extract token expiration") + return nil, false + } + if m.Now().Unix() > int64(exp) { + errorResponse(rw, InvalidRequest, "The token is expired", http.StatusUnauthorized) + return nil, false + } + return token, true +} + +func assertPresence(params []string, rw http.ResponseWriter, req *http.Request) bool { + for _, param := range params { + if req.Form.Get(param) != "" { + continue + } + errorResponse( + rw, + InvalidRequest, + fmt.Sprintf("The request is missing the required parameter: %s", param), + http.StatusBadRequest, + ) + return false + } + return true +} + +func assertEqual(param, value, errorType, errorMsg string, rw http.ResponseWriter, req *http.Request) bool { + formValue := req.Form.Get(param) + if subtle.ConstantTimeCompare([]byte(value), []byte(formValue)) == 0 { + errorResponse(rw, errorType, fmt.Sprintf("%s: %s", errorMsg, formValue), + http.StatusUnauthorized) + return false + } + return true +} + +func validateScope(rw http.ResponseWriter, req *http.Request) bool { + allowed := make(map[string]struct{}) + for _, scope := range ScopesSupported { + allowed[scope] = struct{}{} + } + + scopes := strings.Split(req.Form.Get("scope"), " ") + for _, scope := range scopes { + if _, ok := allowed[scope]; !ok { + errorResponse(rw, InvalidScope, fmt.Sprintf("Unsupported scope: %s", scope), + http.StatusBadRequest) + return false + } + } + return true +} + +func validateCodeChallengeMethodSupported(rw http.ResponseWriter, method string, supportedMethods []string) bool { + if method != "" && !contains(method, supportedMethods) { + errorResponse(rw, InvalidRequest, "Invalid code challenge method", http.StatusBadRequest) + return false + } + return true +} + +func errorResponse(rw http.ResponseWriter, error, description string, statusCode int) { + errJSON := map[string]string{ + "error": error, + "error_description": description, + } + resp, err := json.Marshal(errJSON) + if err != nil { + http.Error(rw, error, http.StatusInternalServerError) + } + + noCache(rw) + rw.Header().Set("Content-Type", applicationJSON) + rw.WriteHeader(statusCode) + + _, err = rw.Write(resp) + if err != nil { + panic(err) + } +} + +func internalServerError(rw http.ResponseWriter, errorMsg string) { + errorResponse(rw, InternalServerError, errorMsg, http.StatusInternalServerError) +} + +func jsonResponse(rw http.ResponseWriter, data []byte) { + noCache(rw) + rw.Header().Set("Content-Type", applicationJSON) + rw.WriteHeader(http.StatusOK) + + _, err := rw.Write(data) + if err != nil { + panic(err) + } +} + +func noCache(rw http.ResponseWriter) { + rw.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate, max-age=0") + rw.Header().Set("Pragma", "no-cache") +} + +func contains(value string, list []string) bool { + for _, element := range list { + if element == value { + return true + } + } + return false +} |
