From 5ffc9b007ccbd8a4510b58de72aaee53291d7973 Mon Sep 17 00:00:00 2001 From: mo khan Date: Wed, 11 Jun 2025 17:11:39 -0600 Subject: refactor: apply SOLID principles --- src/services/implementations.rs | 217 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 217 insertions(+) create mode 100644 src/services/implementations.rs (limited to 'src/services/implementations.rs') diff --git a/src/services/implementations.rs b/src/services/implementations.rs new file mode 100644 index 0000000..ff03165 --- /dev/null +++ b/src/services/implementations.rs @@ -0,0 +1,217 @@ +use super::*; +use crate::clients::parse_basic_auth; +use crate::config::Config; +use crate::database::DbAuditLog; +use crate::keys::KeyManager; +use crate::oauth::types::Claims; +use crate::repositories::{AuditRepository, ClientRepository, RateRepository}; +use anyhow::{Result, anyhow}; +use chrono::Utc; +use jsonwebtoken::{Algorithm, Header, encode}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{SystemTime, UNIX_EPOCH}; +use uuid::Uuid; + +/// Default implementation of ClientAuthenticator +pub struct DefaultClientAuthenticator { + client_repository: Arc, +} + +impl DefaultClientAuthenticator { + pub fn new(client_repository: Arc) -> Self { + Self { client_repository } + } + + fn authenticate_client(&self, client_id: &str, client_secret: &str) -> bool { + match self.client_repository.get_client(client_id) { + Ok(Some(client)) => { + // Use constant-time comparison to prevent timing attacks + use subtle::ConstantTimeEq; + let expected_hash = client.client_secret_hash.as_bytes(); + let provided_hash = self.hash_client_secret(client_secret); + expected_hash.ct_eq(provided_hash.as_bytes()).into() + } + _ => false, + } + } + + fn hash_client_secret(&self, secret: &str) -> String { + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(secret.as_bytes()); + format!("{:x}", hasher.finalize()) + } +} + +impl ClientAuthenticator for DefaultClientAuthenticator { + fn authenticate( + &self, + params: &HashMap, + auth_header: Option<&str>, + ) -> Result<(String, String), String> { + if let Some(auth_header) = auth_header { + // HTTP Basic Authentication (preferred method) + parse_basic_auth(auth_header).ok_or_else(|| "Invalid Authorization header".to_string()) + } else { + // Form-based authentication (fallback) + let client_id = params + .get("client_id") + .ok_or_else(|| "Missing client_id".to_string())?; + let client_secret = params + .get("client_secret") + .ok_or_else(|| "Missing client_secret".to_string())?; + + if !self.authenticate_client(client_id, client_secret) { + return Err("Invalid client credentials".to_string()); + } + + Ok((client_id.clone(), client_secret.clone())) + } + } +} + +/// Default implementation of RateLimiter +pub struct DefaultRateLimiter { + rate_repository: Arc, + config: Config, +} + +impl DefaultRateLimiter { + pub fn new(rate_repository: Arc, config: Config) -> Self { + Self { + rate_repository, + config, + } + } +} + +impl RateLimiter for DefaultRateLimiter { + fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> { + let count = self + .rate_repository + .increment_rate_limit(identifier, endpoint, 1)?; + + if count > self.config.rate_limit_requests_per_minute as i32 { + return Err(anyhow!("Rate limit exceeded")); + } + + Ok(()) + } +} + +/// Default implementation of AuditLogger +pub struct DefaultAuditLogger { + audit_repository: Arc, + config: Config, +} + +impl DefaultAuditLogger { + pub fn new(audit_repository: Arc, config: Config) -> Self { + Self { + audit_repository, + config, + } + } +} + +impl AuditLogger for DefaultAuditLogger { + fn log_event( + &self, + event_type: &str, + client_id: Option<&str>, + user_id: Option<&str>, + ip_address: Option<&str>, + success: bool, + details: Option<&str>, + ) -> Result<()> { + if !self.config.enable_audit_logging { + return Ok(()); + } + + let log = DbAuditLog { + id: 0, + event_type: event_type.to_string(), + client_id: client_id.map(|s| s.to_string()), + user_id: user_id.map(|s| s.to_string()), + ip_address: ip_address.map(|s| s.to_string()), + user_agent: None, // Could be passed in from HTTP layer + details: details.map(|s| s.to_string()), + created_at: Utc::now(), + success, + }; + + self.audit_repository.create_audit_log(&log)?; + Ok(()) + } +} + +/// Default implementation of TokenGenerator +pub struct DefaultTokenGenerator { + key_manager: Arc>, + config: Config, +} + +impl DefaultTokenGenerator { + pub fn new(key_manager: Arc>, config: Config) -> Self { + Self { + key_manager, + config, + } + } +} + +impl TokenGenerator for DefaultTokenGenerator { + fn generate_access_token( + &self, + user_id: &str, + client_id: &str, + scope: &Option, + token_id: &str, + ) -> Result { + let mut key_manager = self.key_manager.lock().unwrap(); + + // Check if we need to rotate keys + if key_manager.should_rotate() { + if let Err(_) = key_manager.rotate_keys() { + return Err("Key rotation failed".to_string()); + } + } + + let current_key = key_manager + .get_current_key() + .ok_or_else(|| "No signing key available".to_string())?; + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let claims = Claims { + sub: user_id.to_string(), + iss: self.config.issuer_url.clone(), + aud: client_id.to_string(), + exp: now + 3600, + iat: now, + scope: scope.clone(), + jti: Some(token_id.to_string()), + }; + + let mut header = Header::new(Algorithm::RS256); + header.kid = Some(current_key.kid.clone()); + + encode(&header, &claims, ¤t_key.encoding_key) + .map_err(|_| "Failed to generate token".to_string()) + } + + fn generate_refresh_token( + &self, + _client_id: &str, + _user_id: &str, + _scope: &Option, + ) -> Result { + // For now, return a simple UUID-based refresh token + // In production, this should be a proper JWT or encrypted token + Ok(Uuid::new_v4().to_string()) + } +} -- cgit v1.2.3