use crate::container::ServiceContainer; use crate::database::{DbAccessToken, DbAuthCode}; use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge}; use crate::oauth::types::{ErrorResponse, TokenIntrospectionResponse, TokenResponse}; use anyhow::Result; use chrono::{Duration, Utc}; use sha2::{Digest, Sha256}; use std::collections::HashMap; use std::sync::Arc; use url::Url; use uuid::Uuid; /// Refactored OAuth service using dependency injection pub struct OAuthService { container: Arc, } impl OAuthService { pub fn new(container: Arc) -> Self { Self { container } } pub fn get_jwks(&self) -> String { self.container.get_jwks() } pub fn handle_authorize( &self, params: &HashMap, ip_address: Option, ) -> Result { let client_id = params .get("client_id") .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; let redirect_uri = params .get("redirect_uri") .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?; let response_type = params .get("response_type") .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?; // Rate limiting check if let Err(e) = self .container .rate_limiter .check_rate_limit(&format!("client:{}", client_id), "/authorize") { let _ = self.container.audit_logger.log_event( "authorize_rate_limited", Some(client_id), None, ip_address.as_deref(), false, Some(&e.to_string()), ); return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); } // Validate client exists let client = match self.container.client_repository.get_client(client_id) { Ok(Some(client)) => client, Ok(None) => { let _ = self.container.audit_logger.log_event( "authorize_invalid_client", Some(client_id), None, ip_address.as_deref(), false, None, ); return Err(self.error_response("invalid_client", "Invalid client_id")); } Err(_) => { return Err(self.error_response("server_error", "Internal server error")); } }; // Validate redirect URI let redirect_uris: Vec = serde_json::from_str(&client.redirect_uris).unwrap_or_else(|_| vec![]); if !redirect_uris.contains(redirect_uri) { let _ = self.container.audit_logger.log_event( "authorize_invalid_redirect_uri", Some(client_id), None, ip_address.as_deref(), false, Some(redirect_uri), ); return Err(self.error_response("invalid_request", "Invalid redirect_uri")); } // Validate requested scopes let scope = params.get("scope").cloned(); if let Some(ref scope_str) = scope { let client_scopes: Vec<&str> = client.scopes.split_whitespace().collect(); let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect(); for requested_scope in &requested_scopes { if !client_scopes.contains(requested_scope) { let _ = self.container.audit_logger.log_event( "authorize_invalid_scope", Some(client_id), None, ip_address.as_deref(), false, scope.as_deref(), ); return Err(self.error_response("invalid_scope", "Invalid scope")); } } } if response_type != "code" { let _ = self.container.audit_logger.log_event( "authorize_unsupported_response_type", Some(client_id), None, ip_address.as_deref(), false, Some(response_type), ); return Err(self.error_response( "unsupported_response_type", "Only code response type supported", )); } // PKCE validation (RFC 7636) let code_challenge = params.get("code_challenge"); let code_challenge_method = params .get("code_challenge_method") .map(|method| CodeChallengeMethod::from_str(method)) .transpose() .map_err(|_| self.error_response("invalid_request", "Invalid code_challenge_method"))?; // For public clients, PKCE is required if client.client_id.starts_with("public_") && code_challenge.is_none() { let _ = self.container.audit_logger.log_event( "authorize_missing_pkce", Some(client_id), None, ip_address.as_deref(), false, None, ); return Err(self.error_response("invalid_request", "PKCE required for public clients")); } let code = Uuid::new_v4().to_string(); let expires_at = Utc::now() + Duration::minutes(10); // 10 minute expiration let db_auth_code = DbAuthCode { id: 0, // Will be set by database code: code.clone(), client_id: client_id.clone(), user_id: "test_user".to_string(), // In a real implementation, this would come from authentication redirect_uri: redirect_uri.clone(), scope: scope.clone(), expires_at, created_at: Utc::now(), is_used: false, code_challenge: code_challenge.cloned(), code_challenge_method: code_challenge_method .as_ref() .map(|m| m.as_str().to_string()), }; // Save to database if let Err(_) = self .container .auth_code_repository .create_auth_code(&db_auth_code) { return Err(self.error_response("server_error", "Failed to create authorization code")); } let mut redirect_url = Url::parse(redirect_uri) .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?; redirect_url.query_pairs_mut().append_pair("code", &code); if let Some(state) = params.get("state") { redirect_url.query_pairs_mut().append_pair("state", state); } let _ = self.container.audit_logger.log_event( "authorize_success", Some(client_id), Some("test_user"), ip_address.as_deref(), true, None, ); Ok(redirect_url.to_string()) } pub fn handle_token( &self, params: &HashMap, auth_header: Option<&str>, ip_address: Option, ) -> Result { let grant_type = params .get("grant_type") .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; match grant_type.as_str() { "authorization_code" => { self.handle_authorization_code_grant(params, auth_header, ip_address) } "refresh_token" => self.handle_refresh_token_grant(params, auth_header, ip_address), _ => { let _ = self.container.audit_logger.log_event( "token_unsupported_grant_type", None, None, ip_address.as_deref(), false, Some(grant_type), ); Err(self.error_response("unsupported_grant_type", "Unsupported grant type")) } } } fn handle_authorization_code_grant( &self, params: &HashMap, auth_header: Option<&str>, ip_address: Option, ) -> Result { let code = params .get("code") .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?; // Client authentication using injected service let (client_id, _client_secret) = self .container .client_authenticator .authenticate(params, auth_header) .map_err(|e| self.error_response("invalid_client", &e))?; // Rate limiting check if let Err(e) = self .container .rate_limiter .check_rate_limit(&format!("client:{}", client_id), "/token") { let _ = self.container.audit_logger.log_event( "token_rate_limited", Some(&client_id), None, ip_address.as_deref(), false, Some(&e.to_string()), ); return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); } // Get and validate authorization code let auth_code = match self.container.auth_code_repository.get_auth_code(code) { Ok(Some(auth_code)) => auth_code, Ok(None) => { let _ = self.container.audit_logger.log_event( "token_invalid_code", Some(&client_id), None, ip_address.as_deref(), false, Some(code), ); return Err( self.error_response("invalid_grant", "Invalid or expired authorization code") ); } Err(_) => { return Err(self.error_response("server_error", "Internal server error")); } }; // Validate code hasn't been used and hasn't expired if auth_code.is_used { let _ = self.container.audit_logger.log_event( "token_code_reuse", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code), ); return Err(self.error_response("invalid_grant", "Authorization code already used")); } if Utc::now() > auth_code.expires_at { let _ = self.container.audit_logger.log_event( "token_code_expired", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code), ); return Err(self.error_response("invalid_grant", "Authorization code expired")); } if auth_code.client_id != client_id { let _ = self.container.audit_logger.log_event( "token_client_mismatch", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None, ); return Err(self.error_response("invalid_grant", "Client ID mismatch")); } // PKCE validation if code challenge was provided if let Some(code_challenge) = &auth_code.code_challenge { let code_verifier = params.get("code_verifier").ok_or_else(|| { self.error_response("invalid_request", "Missing code_verifier for PKCE") })?; let challenge_method = auth_code .code_challenge_method .as_ref() .and_then(|method| CodeChallengeMethod::from_str(method).ok()) .unwrap_or(CodeChallengeMethod::Plain); if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method) { let _ = self.container.audit_logger.log_event( "token_pkce_verification_failed", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None, ); return Err(self.error_response("invalid_grant", "PKCE verification failed")); } } // Mark code as used if let Err(_) = self .container .auth_code_repository .mark_auth_code_used(code) { return Err(self.error_response("server_error", "Failed to mark code as used")); } // Generate tokens using injected service let token_id = Uuid::new_v4().to_string(); let access_token = self.container.token_generator.generate_access_token( &auth_code.user_id, &client_id, &auth_code.scope, &token_id, )?; let refresh_token = self.container.token_generator.generate_refresh_token( &client_id, &auth_code.user_id, &auth_code.scope, )?; // Store token in database for revocation/introspection let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); let db_access_token = DbAccessToken { id: 0, token_id: token_id.clone(), client_id: client_id.clone(), user_id: auth_code.user_id.clone(), scope: auth_code.scope.clone(), expires_at: Utc::now() + Duration::hours(1), created_at: Utc::now(), is_revoked: false, token_hash, }; if let Err(_) = self .container .token_repository .create_access_token(&db_access_token) { return Err(self.error_response("server_error", "Failed to store access token")); } let token_response = TokenResponse { access_token, token_type: "Bearer".to_string(), expires_in: 3600, refresh_token: Some(refresh_token), scope: auth_code.scope, }; let _ = self.container.audit_logger.log_event( "token_success", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), true, None, ); serde_json::to_string(&token_response) .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) } fn handle_refresh_token_grant( &self, params: &HashMap, auth_header: Option<&str>, ip_address: Option, ) -> Result { let _refresh_token = params .get("refresh_token") .ok_or_else(|| self.error_response("invalid_request", "Missing refresh_token"))?; // Client authentication using injected service let (client_id, _client_secret) = self .container .client_authenticator .authenticate(params, auth_header) .map_err(|e| self.error_response("invalid_client", &e))?; // Validate refresh token (implementation would verify token and get user info) // For now, return a simple refresh token response let new_token_id = Uuid::new_v4().to_string(); let access_token = self.container.token_generator.generate_access_token( "test_user", &client_id, &None, &new_token_id, )?; let new_refresh_token = self.container.token_generator.generate_refresh_token( &client_id, "test_user", &None, )?; let token_response = TokenResponse { access_token, token_type: "Bearer".to_string(), expires_in: 3600, refresh_token: Some(new_refresh_token), scope: None, }; let _ = self.container.audit_logger.log_event( "refresh_success", Some(&client_id), Some("test_user"), ip_address.as_deref(), true, None, ); serde_json::to_string(&token_response) .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) } pub fn handle_token_introspection( &self, params: &HashMap, auth_header: Option<&str>, ) -> Result { let token = params .get("token") .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?; // Authenticate the client making the introspection request using injected service let (_client_id, _client_secret) = self .container .client_authenticator .authenticate(params, auth_header) .map_err(|e| self.error_response("invalid_client", &e))?; // Look up token in database using repository let token_hash = format!("{:x}", Sha256::digest(token.as_bytes())); let db_token = self .container .token_repository .get_access_token(&token_hash) .ok() .flatten(); let response = if let Some(db_token) = db_token { if !db_token.is_revoked && Utc::now() < db_token.expires_at { TokenIntrospectionResponse { active: true, client_id: Some(db_token.client_id.clone()), username: Some(db_token.user_id.clone()), scope: db_token.scope.clone(), exp: Some(db_token.expires_at.timestamp() as u64), iat: Some(db_token.created_at.timestamp() as u64), sub: Some(db_token.user_id), aud: Some(db_token.client_id), iss: Some(self.container.config.issuer_url.clone()), jti: Some(db_token.token_id), } } else { TokenIntrospectionResponse::inactive() } } else { TokenIntrospectionResponse::inactive() }; serde_json::to_string(&response) .map_err(|_| self.error_response("server_error", "Failed to serialize response")) } pub fn handle_token_revocation( &self, params: &HashMap, auth_header: Option<&str>, ) -> Result<(), String> { let token = params .get("token") .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?; // Authenticate the client making the revocation request using injected service let (client_id, _client_secret) = self .container .client_authenticator .authenticate(params, auth_header) .map_err(|e| self.error_response("invalid_client", &e))?; // Revoke token in database using repository let token_hash = format!("{:x}", Sha256::digest(token.as_bytes())); let _ = self .container .token_repository .revoke_access_token(&token_hash); // Ignore errors as per RFC 7009 let _ = self.container.audit_logger.log_event( "token_revoked", Some(&client_id), None, None, true, None, ); Ok(()) } fn error_response(&self, error: &str, description: &str) -> String { let error_resp = ErrorResponse { error: error.to_string(), error_description: Some(description.to_string()), error_uri: None, }; serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string()) } /// Cleanup expired data using repositories pub fn cleanup_expired_data(&self) -> Result<()> { self.container.cleanup_expired_data() } }