diff options
Diffstat (limited to 'src/oauth/service.rs')
| -rw-r--r-- | src/oauth/service.rs | 566 |
1 files changed, 566 insertions, 0 deletions
diff --git a/src/oauth/service.rs b/src/oauth/service.rs new file mode 100644 index 0000000..1b4eb49 --- /dev/null +++ b/src/oauth/service.rs @@ -0,0 +1,566 @@ +use crate::container::ServiceContainer; +use crate::database::{DbAccessToken, DbAuthCode}; +use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge}; +use crate::oauth::types::{ErrorResponse, TokenIntrospectionResponse, TokenResponse}; +use anyhow::Result; +use chrono::{Duration, Utc}; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::sync::Arc; +use url::Url; +use uuid::Uuid; + +/// Refactored OAuth service using dependency injection +pub struct OAuthService { + container: Arc<ServiceContainer>, +} + +impl OAuthService { + pub fn new(container: Arc<ServiceContainer>) -> Self { + Self { container } + } + + pub fn get_jwks(&self) -> String { + self.container.get_jwks() + } + + pub fn handle_authorize( + &self, + params: &HashMap<String, String>, + ip_address: Option<String>, + ) -> Result<String, String> { + let client_id = params + .get("client_id") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + + let redirect_uri = params + .get("redirect_uri") + .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?; + + let response_type = params + .get("response_type") + .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?; + + // Rate limiting check + if let Err(e) = self + .container + .rate_limiter + .check_rate_limit(&format!("client:{}", client_id), "/authorize") + { + let _ = self.container.audit_logger.log_event( + "authorize_rate_limited", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(&e.to_string()), + ); + return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); + } + + // Validate client exists + let client = match self.container.client_repository.get_client(client_id) { + Ok(Some(client)) => client, + Ok(None) => { + let _ = self.container.audit_logger.log_event( + "authorize_invalid_client", + Some(client_id), + None, + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_client", "Invalid client_id")); + } + Err(_) => { + return Err(self.error_response("server_error", "Internal server error")); + } + }; + + // Validate redirect URI + let redirect_uris: Vec<String> = + serde_json::from_str(&client.redirect_uris).unwrap_or_else(|_| vec![]); + + if !redirect_uris.contains(redirect_uri) { + let _ = self.container.audit_logger.log_event( + "authorize_invalid_redirect_uri", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(redirect_uri), + ); + return Err(self.error_response("invalid_request", "Invalid redirect_uri")); + } + + // Validate requested scopes + let scope = params.get("scope").cloned(); + if let Some(ref scope_str) = scope { + let client_scopes: Vec<&str> = client.scopes.split_whitespace().collect(); + let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect(); + + for requested_scope in &requested_scopes { + if !client_scopes.contains(requested_scope) { + let _ = self.container.audit_logger.log_event( + "authorize_invalid_scope", + Some(client_id), + None, + ip_address.as_deref(), + false, + scope.as_deref(), + ); + return Err(self.error_response("invalid_scope", "Invalid scope")); + } + } + } + + if response_type != "code" { + let _ = self.container.audit_logger.log_event( + "authorize_unsupported_response_type", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(response_type), + ); + return Err(self.error_response( + "unsupported_response_type", + "Only code response type supported", + )); + } + + // PKCE validation (RFC 7636) + let code_challenge = params.get("code_challenge"); + let code_challenge_method = params + .get("code_challenge_method") + .map(|method| CodeChallengeMethod::from_str(method)) + .transpose() + .map_err(|_| self.error_response("invalid_request", "Invalid code_challenge_method"))?; + + // For public clients, PKCE is required + if client.client_id.starts_with("public_") && code_challenge.is_none() { + let _ = self.container.audit_logger.log_event( + "authorize_missing_pkce", + Some(client_id), + None, + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_request", "PKCE required for public clients")); + } + + let code = Uuid::new_v4().to_string(); + let expires_at = Utc::now() + Duration::minutes(10); // 10 minute expiration + + let db_auth_code = DbAuthCode { + id: 0, // Will be set by database + code: code.clone(), + client_id: client_id.clone(), + user_id: "test_user".to_string(), // In a real implementation, this would come from authentication + redirect_uri: redirect_uri.clone(), + scope: scope.clone(), + expires_at, + created_at: Utc::now(), + is_used: false, + code_challenge: code_challenge.cloned(), + code_challenge_method: code_challenge_method + .as_ref() + .map(|m| m.as_str().to_string()), + }; + + // Save to database + if let Err(_) = self + .container + .auth_code_repository + .create_auth_code(&db_auth_code) + { + return Err(self.error_response("server_error", "Failed to create authorization code")); + } + + let mut redirect_url = Url::parse(redirect_uri) + .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?; + + redirect_url.query_pairs_mut().append_pair("code", &code); + + if let Some(state) = params.get("state") { + redirect_url.query_pairs_mut().append_pair("state", state); + } + + let _ = self.container.audit_logger.log_event( + "authorize_success", + Some(client_id), + Some("test_user"), + ip_address.as_deref(), + true, + None, + ); + + Ok(redirect_url.to_string()) + } + + pub fn handle_token( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ip_address: Option<String>, + ) -> Result<String, String> { + let grant_type = params + .get("grant_type") + .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; + + match grant_type.as_str() { + "authorization_code" => { + self.handle_authorization_code_grant(params, auth_header, ip_address) + } + "refresh_token" => self.handle_refresh_token_grant(params, auth_header, ip_address), + _ => { + let _ = self.container.audit_logger.log_event( + "token_unsupported_grant_type", + None, + None, + ip_address.as_deref(), + false, + Some(grant_type), + ); + Err(self.error_response("unsupported_grant_type", "Unsupported grant type")) + } + } + } + + fn handle_authorization_code_grant( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ip_address: Option<String>, + ) -> Result<String, String> { + let code = params + .get("code") + .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?; + + // Client authentication using injected service + let (client_id, _client_secret) = self + .container + .client_authenticator + .authenticate(params, auth_header) + .map_err(|e| self.error_response("invalid_client", &e))?; + + // Rate limiting check + if let Err(e) = self + .container + .rate_limiter + .check_rate_limit(&format!("client:{}", client_id), "/token") + { + let _ = self.container.audit_logger.log_event( + "token_rate_limited", + Some(&client_id), + None, + ip_address.as_deref(), + false, + Some(&e.to_string()), + ); + return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); + } + + // Get and validate authorization code + let auth_code = match self.container.auth_code_repository.get_auth_code(code) { + Ok(Some(auth_code)) => auth_code, + Ok(None) => { + let _ = self.container.audit_logger.log_event( + "token_invalid_code", + Some(&client_id), + None, + ip_address.as_deref(), + false, + Some(code), + ); + return Err( + self.error_response("invalid_grant", "Invalid or expired authorization code") + ); + } + Err(_) => { + return Err(self.error_response("server_error", "Internal server error")); + } + }; + + // Validate code hasn't been used and hasn't expired + if auth_code.is_used { + let _ = self.container.audit_logger.log_event( + "token_code_reuse", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + Some(code), + ); + return Err(self.error_response("invalid_grant", "Authorization code already used")); + } + + if Utc::now() > auth_code.expires_at { + let _ = self.container.audit_logger.log_event( + "token_code_expired", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + Some(code), + ); + return Err(self.error_response("invalid_grant", "Authorization code expired")); + } + + if auth_code.client_id != client_id { + let _ = self.container.audit_logger.log_event( + "token_client_mismatch", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Client ID mismatch")); + } + + // PKCE validation if code challenge was provided + if let Some(code_challenge) = &auth_code.code_challenge { + let code_verifier = params.get("code_verifier").ok_or_else(|| { + self.error_response("invalid_request", "Missing code_verifier for PKCE") + })?; + + let challenge_method = auth_code + .code_challenge_method + .as_ref() + .and_then(|method| CodeChallengeMethod::from_str(method).ok()) + .unwrap_or(CodeChallengeMethod::Plain); + + if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method) + { + let _ = self.container.audit_logger.log_event( + "token_pkce_verification_failed", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "PKCE verification failed")); + } + } + + // Mark code as used + if let Err(_) = self + .container + .auth_code_repository + .mark_auth_code_used(code) + { + return Err(self.error_response("server_error", "Failed to mark code as used")); + } + + // Generate tokens using injected service + let token_id = Uuid::new_v4().to_string(); + let access_token = self.container.token_generator.generate_access_token( + &auth_code.user_id, + &client_id, + &auth_code.scope, + &token_id, + )?; + let refresh_token = self.container.token_generator.generate_refresh_token( + &client_id, + &auth_code.user_id, + &auth_code.scope, + )?; + + // Store token in database for revocation/introspection + let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); + let db_access_token = DbAccessToken { + id: 0, + token_id: token_id.clone(), + client_id: client_id.clone(), + user_id: auth_code.user_id.clone(), + scope: auth_code.scope.clone(), + expires_at: Utc::now() + Duration::hours(1), + created_at: Utc::now(), + is_revoked: false, + token_hash, + }; + + if let Err(_) = self + .container + .token_repository + .create_access_token(&db_access_token) + { + return Err(self.error_response("server_error", "Failed to store access token")); + } + + let token_response = TokenResponse { + access_token, + token_type: "Bearer".to_string(), + expires_in: 3600, + refresh_token: Some(refresh_token), + scope: auth_code.scope, + }; + + let _ = self.container.audit_logger.log_event( + "token_success", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + true, + None, + ); + + serde_json::to_string(&token_response) + .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) + } + + fn handle_refresh_token_grant( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ip_address: Option<String>, + ) -> Result<String, String> { + let _refresh_token = params + .get("refresh_token") + .ok_or_else(|| self.error_response("invalid_request", "Missing refresh_token"))?; + + // Client authentication using injected service + let (client_id, _client_secret) = self + .container + .client_authenticator + .authenticate(params, auth_header) + .map_err(|e| self.error_response("invalid_client", &e))?; + + // Validate refresh token (implementation would verify token and get user info) + // For now, return a simple refresh token response + let new_token_id = Uuid::new_v4().to_string(); + let access_token = self.container.token_generator.generate_access_token( + "test_user", + &client_id, + &None, + &new_token_id, + )?; + let new_refresh_token = self.container.token_generator.generate_refresh_token( + &client_id, + "test_user", + &None, + )?; + + let token_response = TokenResponse { + access_token, + token_type: "Bearer".to_string(), + expires_in: 3600, + refresh_token: Some(new_refresh_token), + scope: None, + }; + + let _ = self.container.audit_logger.log_event( + "refresh_success", + Some(&client_id), + Some("test_user"), + ip_address.as_deref(), + true, + None, + ); + + serde_json::to_string(&token_response) + .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) + } + + pub fn handle_token_introspection( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ) -> Result<String, String> { + let token = params + .get("token") + .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?; + + // Authenticate the client making the introspection request using injected service + let (_client_id, _client_secret) = self + .container + .client_authenticator + .authenticate(params, auth_header) + .map_err(|e| self.error_response("invalid_client", &e))?; + + // Look up token in database using repository + let token_hash = format!("{:x}", Sha256::digest(token.as_bytes())); + let db_token = self + .container + .token_repository + .get_access_token(&token_hash) + .ok() + .flatten(); + + let response = if let Some(db_token) = db_token { + if !db_token.is_revoked && Utc::now() < db_token.expires_at { + TokenIntrospectionResponse { + active: true, + client_id: Some(db_token.client_id.clone()), + username: Some(db_token.user_id.clone()), + scope: db_token.scope.clone(), + exp: Some(db_token.expires_at.timestamp() as u64), + iat: Some(db_token.created_at.timestamp() as u64), + sub: Some(db_token.user_id), + aud: Some(db_token.client_id), + iss: Some(self.container.config.issuer_url.clone()), + jti: Some(db_token.token_id), + } + } else { + TokenIntrospectionResponse::inactive() + } + } else { + TokenIntrospectionResponse::inactive() + }; + + serde_json::to_string(&response) + .map_err(|_| self.error_response("server_error", "Failed to serialize response")) + } + + pub fn handle_token_revocation( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ) -> Result<(), String> { + let token = params + .get("token") + .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?; + + // Authenticate the client making the revocation request using injected service + let (client_id, _client_secret) = self + .container + .client_authenticator + .authenticate(params, auth_header) + .map_err(|e| self.error_response("invalid_client", &e))?; + + // Revoke token in database using repository + let token_hash = format!("{:x}", Sha256::digest(token.as_bytes())); + let _ = self + .container + .token_repository + .revoke_access_token(&token_hash); // Ignore errors as per RFC 7009 + + let _ = self.container.audit_logger.log_event( + "token_revoked", + Some(&client_id), + None, + None, + true, + None, + ); + + Ok(()) + } + + fn error_response(&self, error: &str, description: &str) -> String { + let error_resp = ErrorResponse { + error: error.to_string(), + error_description: Some(description.to_string()), + error_uri: None, + }; + serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string()) + } + + /// Cleanup expired data using repositories + pub fn cleanup_expired_data(&self) -> Result<()> { + self.container.cleanup_expired_data() + } +} |
