From 5ffc9b007ccbd8a4510b58de72aaee53291d7973 Mon Sep 17 00:00:00 2001 From: mo khan Date: Wed, 11 Jun 2025 17:11:39 -0600 Subject: refactor: apply SOLID principles --- src/oauth/mod.rs | 4 +- src/oauth/pkce.rs | 18 +- src/oauth/server.rs | 8 +- src/oauth/service.rs | 566 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/oauth/types.rs | 17 ++ 5 files changed, 600 insertions(+), 13 deletions(-) create mode 100644 src/oauth/service.rs (limited to 'src/oauth') diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs index 4b18bb3..b2d46fa 100644 --- a/src/oauth/mod.rs +++ b/src/oauth/mod.rs @@ -1,9 +1,11 @@ pub mod pkce; pub mod server; +pub mod service; pub mod types; pub use pkce::{ - generate_code_challenge, generate_code_verifier, verify_code_challenge, CodeChallengeMethod, + CodeChallengeMethod, generate_code_challenge, generate_code_verifier, verify_code_challenge, }; pub use server::OAuthServer; +pub use service::OAuthService; pub use types::{AuthCode, Claims, ErrorResponse, TokenResponse}; diff --git a/src/oauth/pkce.rs b/src/oauth/pkce.rs index 406d364..0dfc1f8 100644 --- a/src/oauth/pkce.rs +++ b/src/oauth/pkce.rs @@ -1,5 +1,5 @@ -use anyhow::{anyhow, Result}; -use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; +use anyhow::{Result, anyhow}; +use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; use sha2::{Digest, Sha256}; #[derive(Debug, Clone, PartialEq)] @@ -124,12 +124,14 @@ mod tests { assert!(verify_code_challenge("short", "challenge", &CodeChallengeMethod::Plain).is_err()); // Invalid characters - assert!(verify_code_challenge( - "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", - "challenge", - &CodeChallengeMethod::Plain - ) - .is_err()); + assert!( + verify_code_challenge( + "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", + "challenge", + &CodeChallengeMethod::Plain + ) + .is_err() + ); } #[test] diff --git a/src/oauth/server.rs b/src/oauth/server.rs index 7fd8b9c..37c3cbc 100644 --- a/src/oauth/server.rs +++ b/src/oauth/server.rs @@ -1,12 +1,12 @@ -use crate::clients::{parse_basic_auth, ClientManager}; +use crate::clients::{ClientManager, parse_basic_auth}; use crate::config::Config; use crate::database::{Database, DbAccessToken, DbAuditLog, DbAuthCode}; use crate::keys::KeyManager; -use crate::oauth::pkce::{verify_code_challenge, CodeChallengeMethod}; +use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge}; use crate::oauth::types::{Claims, ErrorResponse, TokenIntrospectionResponse, TokenResponse}; -use anyhow::{anyhow, Result}; +use anyhow::{Result, anyhow}; use chrono::{Duration, Utc}; -use jsonwebtoken::{encode, Algorithm, Header}; +use jsonwebtoken::{Algorithm, Header, encode}; use sha2::{Digest, Sha256}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; diff --git a/src/oauth/service.rs b/src/oauth/service.rs new file mode 100644 index 0000000..1b4eb49 --- /dev/null +++ b/src/oauth/service.rs @@ -0,0 +1,566 @@ +use crate::container::ServiceContainer; +use crate::database::{DbAccessToken, DbAuthCode}; +use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge}; +use crate::oauth::types::{ErrorResponse, TokenIntrospectionResponse, TokenResponse}; +use anyhow::Result; +use chrono::{Duration, Utc}; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::sync::Arc; +use url::Url; +use uuid::Uuid; + +/// Refactored OAuth service using dependency injection +pub struct OAuthService { + container: Arc, +} + +impl OAuthService { + pub fn new(container: Arc) -> Self { + Self { container } + } + + pub fn get_jwks(&self) -> String { + self.container.get_jwks() + } + + pub fn handle_authorize( + &self, + params: &HashMap, + ip_address: Option, + ) -> Result { + let client_id = params + .get("client_id") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + + let redirect_uri = params + .get("redirect_uri") + .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?; + + let response_type = params + .get("response_type") + .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?; + + // Rate limiting check + if let Err(e) = self + .container + .rate_limiter + .check_rate_limit(&format!("client:{}", client_id), "/authorize") + { + let _ = self.container.audit_logger.log_event( + "authorize_rate_limited", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(&e.to_string()), + ); + return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); + } + + // Validate client exists + let client = match self.container.client_repository.get_client(client_id) { + Ok(Some(client)) => client, + Ok(None) => { + let _ = self.container.audit_logger.log_event( + "authorize_invalid_client", + Some(client_id), + None, + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_client", "Invalid client_id")); + } + Err(_) => { + return Err(self.error_response("server_error", "Internal server error")); + } + }; + + // Validate redirect URI + let redirect_uris: Vec = + serde_json::from_str(&client.redirect_uris).unwrap_or_else(|_| vec![]); + + if !redirect_uris.contains(redirect_uri) { + let _ = self.container.audit_logger.log_event( + "authorize_invalid_redirect_uri", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(redirect_uri), + ); + return Err(self.error_response("invalid_request", "Invalid redirect_uri")); + } + + // Validate requested scopes + let scope = params.get("scope").cloned(); + if let Some(ref scope_str) = scope { + let client_scopes: Vec<&str> = client.scopes.split_whitespace().collect(); + let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect(); + + for requested_scope in &requested_scopes { + if !client_scopes.contains(requested_scope) { + let _ = self.container.audit_logger.log_event( + "authorize_invalid_scope", + Some(client_id), + None, + ip_address.as_deref(), + false, + scope.as_deref(), + ); + return Err(self.error_response("invalid_scope", "Invalid scope")); + } + } + } + + if response_type != "code" { + let _ = self.container.audit_logger.log_event( + "authorize_unsupported_response_type", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(response_type), + ); + return Err(self.error_response( + "unsupported_response_type", + "Only code response type supported", + )); + } + + // PKCE validation (RFC 7636) + let code_challenge = params.get("code_challenge"); + let code_challenge_method = params + .get("code_challenge_method") + .map(|method| CodeChallengeMethod::from_str(method)) + .transpose() + .map_err(|_| self.error_response("invalid_request", "Invalid code_challenge_method"))?; + + // For public clients, PKCE is required + if client.client_id.starts_with("public_") && code_challenge.is_none() { + let _ = self.container.audit_logger.log_event( + "authorize_missing_pkce", + Some(client_id), + None, + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_request", "PKCE required for public clients")); + } + + let code = Uuid::new_v4().to_string(); + let expires_at = Utc::now() + Duration::minutes(10); // 10 minute expiration + + let db_auth_code = DbAuthCode { + id: 0, // Will be set by database + code: code.clone(), + client_id: client_id.clone(), + user_id: "test_user".to_string(), // In a real implementation, this would come from authentication + redirect_uri: redirect_uri.clone(), + scope: scope.clone(), + expires_at, + created_at: Utc::now(), + is_used: false, + code_challenge: code_challenge.cloned(), + code_challenge_method: code_challenge_method + .as_ref() + .map(|m| m.as_str().to_string()), + }; + + // Save to database + if let Err(_) = self + .container + .auth_code_repository + .create_auth_code(&db_auth_code) + { + return Err(self.error_response("server_error", "Failed to create authorization code")); + } + + let mut redirect_url = Url::parse(redirect_uri) + .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?; + + redirect_url.query_pairs_mut().append_pair("code", &code); + + if let Some(state) = params.get("state") { + redirect_url.query_pairs_mut().append_pair("state", state); + } + + let _ = self.container.audit_logger.log_event( + "authorize_success", + Some(client_id), + Some("test_user"), + ip_address.as_deref(), + true, + None, + ); + + Ok(redirect_url.to_string()) + } + + pub fn handle_token( + &self, + params: &HashMap, + auth_header: Option<&str>, + ip_address: Option, + ) -> Result { + let grant_type = params + .get("grant_type") + .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; + + match grant_type.as_str() { + "authorization_code" => { + self.handle_authorization_code_grant(params, auth_header, ip_address) + } + "refresh_token" => self.handle_refresh_token_grant(params, auth_header, ip_address), + _ => { + let _ = self.container.audit_logger.log_event( + "token_unsupported_grant_type", + None, + None, + ip_address.as_deref(), + false, + Some(grant_type), + ); + Err(self.error_response("unsupported_grant_type", "Unsupported grant type")) + } + } + } + + fn handle_authorization_code_grant( + &self, + params: &HashMap, + auth_header: Option<&str>, + ip_address: Option, + ) -> Result { + let code = params + .get("code") + .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?; + + // Client authentication using injected service + let (client_id, _client_secret) = self + .container + .client_authenticator + .authenticate(params, auth_header) + .map_err(|e| self.error_response("invalid_client", &e))?; + + // Rate limiting check + if let Err(e) = self + .container + .rate_limiter + .check_rate_limit(&format!("client:{}", client_id), "/token") + { + let _ = self.container.audit_logger.log_event( + "token_rate_limited", + Some(&client_id), + None, + ip_address.as_deref(), + false, + Some(&e.to_string()), + ); + return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); + } + + // Get and validate authorization code + let auth_code = match self.container.auth_code_repository.get_auth_code(code) { + Ok(Some(auth_code)) => auth_code, + Ok(None) => { + let _ = self.container.audit_logger.log_event( + "token_invalid_code", + Some(&client_id), + None, + ip_address.as_deref(), + false, + Some(code), + ); + return Err( + self.error_response("invalid_grant", "Invalid or expired authorization code") + ); + } + Err(_) => { + return Err(self.error_response("server_error", "Internal server error")); + } + }; + + // Validate code hasn't been used and hasn't expired + if auth_code.is_used { + let _ = self.container.audit_logger.log_event( + "token_code_reuse", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + Some(code), + ); + return Err(self.error_response("invalid_grant", "Authorization code already used")); + } + + if Utc::now() > auth_code.expires_at { + let _ = self.container.audit_logger.log_event( + "token_code_expired", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + Some(code), + ); + return Err(self.error_response("invalid_grant", "Authorization code expired")); + } + + if auth_code.client_id != client_id { + let _ = self.container.audit_logger.log_event( + "token_client_mismatch", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Client ID mismatch")); + } + + // PKCE validation if code challenge was provided + if let Some(code_challenge) = &auth_code.code_challenge { + let code_verifier = params.get("code_verifier").ok_or_else(|| { + self.error_response("invalid_request", "Missing code_verifier for PKCE") + })?; + + let challenge_method = auth_code + .code_challenge_method + .as_ref() + .and_then(|method| CodeChallengeMethod::from_str(method).ok()) + .unwrap_or(CodeChallengeMethod::Plain); + + if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method) + { + let _ = self.container.audit_logger.log_event( + "token_pkce_verification_failed", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "PKCE verification failed")); + } + } + + // Mark code as used + if let Err(_) = self + .container + .auth_code_repository + .mark_auth_code_used(code) + { + return Err(self.error_response("server_error", "Failed to mark code as used")); + } + + // Generate tokens using injected service + let token_id = Uuid::new_v4().to_string(); + let access_token = self.container.token_generator.generate_access_token( + &auth_code.user_id, + &client_id, + &auth_code.scope, + &token_id, + )?; + let refresh_token = self.container.token_generator.generate_refresh_token( + &client_id, + &auth_code.user_id, + &auth_code.scope, + )?; + + // Store token in database for revocation/introspection + let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); + let db_access_token = DbAccessToken { + id: 0, + token_id: token_id.clone(), + client_id: client_id.clone(), + user_id: auth_code.user_id.clone(), + scope: auth_code.scope.clone(), + expires_at: Utc::now() + Duration::hours(1), + created_at: Utc::now(), + is_revoked: false, + token_hash, + }; + + if let Err(_) = self + .container + .token_repository + .create_access_token(&db_access_token) + { + return Err(self.error_response("server_error", "Failed to store access token")); + } + + let token_response = TokenResponse { + access_token, + token_type: "Bearer".to_string(), + expires_in: 3600, + refresh_token: Some(refresh_token), + scope: auth_code.scope, + }; + + let _ = self.container.audit_logger.log_event( + "token_success", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + true, + None, + ); + + serde_json::to_string(&token_response) + .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) + } + + fn handle_refresh_token_grant( + &self, + params: &HashMap, + auth_header: Option<&str>, + ip_address: Option, + ) -> Result { + let _refresh_token = params + .get("refresh_token") + .ok_or_else(|| self.error_response("invalid_request", "Missing refresh_token"))?; + + // Client authentication using injected service + let (client_id, _client_secret) = self + .container + .client_authenticator + .authenticate(params, auth_header) + .map_err(|e| self.error_response("invalid_client", &e))?; + + // Validate refresh token (implementation would verify token and get user info) + // For now, return a simple refresh token response + let new_token_id = Uuid::new_v4().to_string(); + let access_token = self.container.token_generator.generate_access_token( + "test_user", + &client_id, + &None, + &new_token_id, + )?; + let new_refresh_token = self.container.token_generator.generate_refresh_token( + &client_id, + "test_user", + &None, + )?; + + let token_response = TokenResponse { + access_token, + token_type: "Bearer".to_string(), + expires_in: 3600, + refresh_token: Some(new_refresh_token), + scope: None, + }; + + let _ = self.container.audit_logger.log_event( + "refresh_success", + Some(&client_id), + Some("test_user"), + ip_address.as_deref(), + true, + None, + ); + + serde_json::to_string(&token_response) + .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) + } + + pub fn handle_token_introspection( + &self, + params: &HashMap, + auth_header: Option<&str>, + ) -> Result { + let token = params + .get("token") + .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?; + + // Authenticate the client making the introspection request using injected service + let (_client_id, _client_secret) = self + .container + .client_authenticator + .authenticate(params, auth_header) + .map_err(|e| self.error_response("invalid_client", &e))?; + + // Look up token in database using repository + let token_hash = format!("{:x}", Sha256::digest(token.as_bytes())); + let db_token = self + .container + .token_repository + .get_access_token(&token_hash) + .ok() + .flatten(); + + let response = if let Some(db_token) = db_token { + if !db_token.is_revoked && Utc::now() < db_token.expires_at { + TokenIntrospectionResponse { + active: true, + client_id: Some(db_token.client_id.clone()), + username: Some(db_token.user_id.clone()), + scope: db_token.scope.clone(), + exp: Some(db_token.expires_at.timestamp() as u64), + iat: Some(db_token.created_at.timestamp() as u64), + sub: Some(db_token.user_id), + aud: Some(db_token.client_id), + iss: Some(self.container.config.issuer_url.clone()), + jti: Some(db_token.token_id), + } + } else { + TokenIntrospectionResponse::inactive() + } + } else { + TokenIntrospectionResponse::inactive() + }; + + serde_json::to_string(&response) + .map_err(|_| self.error_response("server_error", "Failed to serialize response")) + } + + pub fn handle_token_revocation( + &self, + params: &HashMap, + auth_header: Option<&str>, + ) -> Result<(), String> { + let token = params + .get("token") + .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?; + + // Authenticate the client making the revocation request using injected service + let (client_id, _client_secret) = self + .container + .client_authenticator + .authenticate(params, auth_header) + .map_err(|e| self.error_response("invalid_client", &e))?; + + // Revoke token in database using repository + let token_hash = format!("{:x}", Sha256::digest(token.as_bytes())); + let _ = self + .container + .token_repository + .revoke_access_token(&token_hash); // Ignore errors as per RFC 7009 + + let _ = self.container.audit_logger.log_event( + "token_revoked", + Some(&client_id), + None, + None, + true, + None, + ); + + Ok(()) + } + + fn error_response(&self, error: &str, description: &str) -> String { + let error_resp = ErrorResponse { + error: error.to_string(), + error_description: Some(description.to_string()), + error_uri: None, + }; + serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string()) + } + + /// Cleanup expired data using repositories + pub fn cleanup_expired_data(&self) -> Result<()> { + self.container.cleanup_expired_data() + } +} diff --git a/src/oauth/types.rs b/src/oauth/types.rs index 4f2c363..3d1c581 100644 --- a/src/oauth/types.rs +++ b/src/oauth/types.rs @@ -76,6 +76,23 @@ pub struct TokenIntrospectionResponse { pub jti: Option, } +impl TokenIntrospectionResponse { + pub fn inactive() -> Self { + Self { + active: false, + client_id: None, + username: None, + scope: None, + exp: None, + iat: None, + sub: None, + aud: None, + iss: None, + jti: None, + } + } +} + #[derive(Debug, Serialize, Deserialize)] pub struct TokenRevocationRequest { pub token: String, -- cgit v1.2.3