diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-11 17:11:39 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-11 17:11:39 -0600 |
| commit | 5ffc9b007ccbd8a4510b58de72aaee53291d7973 (patch) | |
| tree | f696a2a7599926d402c5456c434bd87e5e325c3a /src | |
| parent | dbd3c780f27bd5bee23adf6e280b84d669230e0d (diff) | |
refactor: apply SOLID principles
Diffstat (limited to 'src')
| -rw-r--r-- | src/bin/debug.rs | 28 | ||||
| -rw-r--r-- | src/container.rs | 103 | ||||
| -rw-r--r-- | src/database.rs | 117 | ||||
| -rw-r--r-- | src/http/mod.rs | 86 | ||||
| -rw-r--r-- | src/lib.rs | 5 | ||||
| -rw-r--r-- | src/main.rs | 134 | ||||
| -rw-r--r-- | src/oauth/mod.rs | 4 | ||||
| -rw-r--r-- | src/oauth/pkce.rs | 18 | ||||
| -rw-r--r-- | src/oauth/server.rs | 8 | ||||
| -rw-r--r-- | src/oauth/service.rs | 566 | ||||
| -rw-r--r-- | src/oauth/types.rs | 17 | ||||
| -rw-r--r-- | src/repositories/mod.rs | 52 | ||||
| -rw-r--r-- | src/repositories/sqlite.rs | 164 | ||||
| -rw-r--r-- | src/services/implementations.rs | 217 | ||||
| -rw-r--r-- | src/services/mod.rs | 49 |
15 files changed, 1490 insertions, 78 deletions
diff --git a/src/bin/debug.rs b/src/bin/debug.rs index 6d80848..e05446b 100644 --- a/src/bin/debug.rs +++ b/src/bin/debug.rs @@ -1,21 +1,25 @@ fn main() { let config = sts::Config::from_env(); println!("Config loaded: {}", config.bind_addr); + + // Try the old-style server creation let server = sts::http::Server::new(config.clone()); println!("Server result: {:?}", server.is_ok()); - if let Ok(server) = server { - let oauth_server = &server.oauth_server; - let jwks = oauth_server.get_jwks(); - println!("JWKS length: {}", jwks.len()); - println!( - "JWKS: {}", - if jwks.len() > 100 { - &jwks[..100] - } else { - &jwks - } - ); + if let Ok(_server) = server { + // Create OAuth server directly to test JWKS + if let Ok(oauth_server) = sts::OAuthServer::new(&config) { + let jwks = oauth_server.get_jwks(); + println!("JWKS length: {}", jwks.len()); + println!( + "JWKS: {}", + if jwks.len() > 100 { + &jwks[..100] + } else { + &jwks + } + ); + } } let metadata = serde_json::json!({ diff --git a/src/container.rs b/src/container.rs new file mode 100644 index 0000000..3a4b13e --- /dev/null +++ b/src/container.rs @@ -0,0 +1,103 @@ +use crate::config::Config; +use crate::database::Database; +use crate::keys::KeyManager; +use crate::repositories::*; +use crate::services::implementations::*; +use crate::services::*; +use anyhow::Result; +use std::sync::{Arc, Mutex}; + +/// Dependency injection container for all services and repositories +pub struct ServiceContainer { + // Repositories + pub client_repository: Arc<dyn ClientRepository>, + pub auth_code_repository: Arc<dyn AuthCodeRepository>, + pub token_repository: Arc<dyn TokenRepository>, + pub audit_repository: Arc<dyn AuditRepository>, + pub rate_repository: Arc<dyn RateRepository>, + + // Services + pub client_authenticator: Arc<dyn ClientAuthenticator>, + pub rate_limiter: Arc<dyn RateLimiter>, + pub audit_logger: Arc<dyn AuditLogger>, + pub token_generator: Arc<dyn TokenGenerator>, + + // Core components + pub key_manager: Arc<Mutex<KeyManager>>, + pub config: Config, +} + +impl ServiceContainer { + pub fn new(config: Config, database: Arc<Mutex<Database>>) -> Result<Self> { + // Create repositories + let client_repository: Arc<dyn ClientRepository> = + Arc::new(SqliteClientRepository::new(database.clone())); + let auth_code_repository: Arc<dyn AuthCodeRepository> = + Arc::new(SqliteAuthCodeRepository::new(database.clone())); + let token_repository: Arc<dyn TokenRepository> = + Arc::new(SqliteTokenRepository::new(database.clone())); + let audit_repository: Arc<dyn AuditRepository> = + Arc::new(SqliteAuditRepository::new(database.clone())); + let rate_repository: Arc<dyn RateRepository> = + Arc::new(SqliteRateRepository::new(database.clone())); + + // Create key manager + let key_manager = Arc::new(Mutex::new(KeyManager::new(database.clone())?)); + + // Create services + let client_authenticator: Arc<dyn ClientAuthenticator> = + Arc::new(DefaultClientAuthenticator::new(client_repository.clone())); + let rate_limiter: Arc<dyn RateLimiter> = Arc::new(DefaultRateLimiter::new( + rate_repository.clone(), + config.clone(), + )); + let audit_logger: Arc<dyn AuditLogger> = Arc::new(DefaultAuditLogger::new( + audit_repository.clone(), + config.clone(), + )); + let token_generator: Arc<dyn TokenGenerator> = Arc::new(DefaultTokenGenerator::new( + key_manager.clone(), + config.clone(), + )); + + Ok(Self { + client_repository, + auth_code_repository, + token_repository, + audit_repository, + rate_repository, + client_authenticator, + rate_limiter, + audit_logger, + token_generator, + key_manager, + config, + }) + } + + /// Get JWKS from the key manager + pub fn get_jwks(&self) -> String { + let key_manager = self.key_manager.lock().unwrap(); + match key_manager.get_jwks() { + Ok(jwks) => serde_json::to_string(&jwks).unwrap_or_else(|_| "{}".to_string()), + Err(_) => serde_json::json!({"keys": []}).to_string(), + } + } + + /// Cleanup expired data + pub fn cleanup_expired_data(&self) -> Result<()> { + // Cleanup expired authorization codes + let _ = self.auth_code_repository.cleanup_expired_codes(); + + // Cleanup expired tokens + let _ = self.token_repository.cleanup_expired_tokens(); + + // Cleanup old audit logs (keep for 30 days) + let _ = self.audit_repository.cleanup_old_audit_logs(30); + + // Cleanup old rate limits + let _ = self.rate_repository.cleanup_old_rate_limits(); + + Ok(()) + } +} diff --git a/src/database.rs b/src/database.rs index 2472d1a..5251dac 100644 --- a/src/database.rs +++ b/src/database.rs @@ -665,6 +665,123 @@ impl Database { )?; Ok(affected) } + + // Additional methods needed for repository patterns + pub fn update_oauth_client(&self, client: &DbOAuthClient) -> Result<()> { + self.conn.execute( + "UPDATE oauth_clients SET + client_secret_hash = ?2, client_name = ?3, redirect_uris = ?4, + scopes = ?5, grant_types = ?6, response_types = ?7, + updated_at = ?8, is_active = ?9 + WHERE client_id = ?1", + params![ + client.client_id, + client.client_secret_hash, + client.client_name, + client.redirect_uris, + client.scopes, + client.grant_types, + client.response_types, + client.updated_at.to_rfc3339(), + client.is_active + ], + )?; + Ok(()) + } + + pub fn delete_oauth_client(&self, client_id: &str) -> Result<()> { + self.conn.execute( + "DELETE FROM oauth_clients WHERE client_id = ?1", + [client_id], + )?; + Ok(()) + } + + pub fn list_oauth_clients(&self) -> Result<Vec<DbOAuthClient>> { + let mut stmt = self.conn.prepare( + "SELECT id, client_id, client_secret_hash, client_name, redirect_uris, + scopes, grant_types, response_types, created_at, updated_at, is_active + FROM oauth_clients ORDER BY created_at DESC", + )?; + + let clients = stmt + .query_map([], |row| { + Ok(DbOAuthClient { + id: row.get(0)?, + client_id: row.get(1)?, + client_secret_hash: row.get(2)?, + client_name: row.get(3)?, + redirect_uris: row.get(4)?, + scopes: row.get(5)?, + grant_types: row.get(6)?, + response_types: row.get(7)?, + created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?) + .map_err(|e| { + rusqlite::Error::FromSqlConversionFailure( + 8, + rusqlite::types::Type::Text, + Box::new(e), + ) + })? + .with_timezone(&Utc), + updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?) + .map_err(|e| { + rusqlite::Error::FromSqlConversionFailure( + 9, + rusqlite::types::Type::Text, + Box::new(e), + ) + })? + .with_timezone(&Utc), + is_active: row.get(10)?, + }) + })? + .collect::<Result<Vec<_>, _>>()?; + + Ok(clients) + } + + pub fn get_audit_logs(&self, limit: i32) -> Result<Vec<DbAuditLog>> { + let mut stmt = self.conn.prepare( + "SELECT id, event_type, client_id, user_id, ip_address, user_agent, details, created_at, success + FROM audit_logs ORDER BY created_at DESC LIMIT ?1" + )?; + + let logs = stmt + .query_map([limit], |row| { + Ok(DbAuditLog { + id: row.get(0)?, + event_type: row.get(1)?, + client_id: row.get(2)?, + user_id: row.get(3)?, + ip_address: row.get(4)?, + user_agent: row.get(5)?, + details: row.get(6)?, + created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?) + .map_err(|e| { + rusqlite::Error::FromSqlConversionFailure( + 7, + rusqlite::types::Type::Text, + Box::new(e), + ) + })? + .with_timezone(&Utc), + success: row.get(8)?, + }) + })? + .collect::<Result<Vec<_>, _>>()?; + + Ok(logs) + } + + pub fn cleanup_old_rate_limits(&self) -> Result<()> { + let cutoff = Utc::now() - chrono::Duration::hours(24); // Clean up rate limits older than 24 hours + self.conn.execute( + "DELETE FROM rate_limits WHERE created_at < ?1", + [cutoff.to_rfc3339()], + )?; + Ok(()) + } } #[cfg(test)] diff --git a/src/http/mod.rs b/src/http/mod.rs index 1bc7951..778a3de 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -1,21 +1,38 @@ use crate::config::Config; -use crate::oauth::OAuthServer; +use crate::container::ServiceContainer; +use crate::oauth::{OAuthServer, OAuthService}; use std::collections::HashMap; use std::fs; use std::io::prelude::*; use std::net::{TcpListener, TcpStream}; +use std::sync::Arc; use url::Url; pub struct Server { config: Config, - pub oauth_server: OAuthServer, + oauth_server: Option<OAuthServer>, + oauth_service: Option<Arc<ServiceContainer>>, } impl Server { pub fn new(config: Config) -> Result<Server, Box<dyn std::error::Error>> { Ok(Server { - oauth_server: OAuthServer::new(&config) - .map_err(|e| format!("Failed to create OAuth server: {}", e))?, + oauth_server: Some( + OAuthServer::new(&config) + .map_err(|e| format!("Failed to create OAuth server: {}", e))?, + ), + oauth_service: None, + config, + }) + } + + pub fn new_with_container( + config: Config, + container: Arc<ServiceContainer>, + ) -> Result<Server, Box<dyn std::error::Error>> { + Ok(Server { + oauth_server: None, + oauth_service: Some(container), config, }) } @@ -194,7 +211,13 @@ impl Server { } fn handle_jwks(&self, stream: &mut TcpStream) { - let jwks = self.oauth_server.get_jwks(); + let jwks = if let Some(ref oauth_server) = self.oauth_server { + oauth_server.get_jwks() + } else if let Some(ref container) = self.oauth_service { + container.get_jwks() + } else { + "{\"keys\":[]}".to_string() + }; self.send_json_response(stream, 200, "OK", &jwks); } @@ -204,7 +227,16 @@ impl Server { params: &HashMap<String, String>, ip_address: Option<String>, ) { - match self.oauth_server.handle_authorize(params, ip_address) { + let result = if let Some(ref oauth_server) = self.oauth_server { + oauth_server.handle_authorize(params, ip_address) + } else if let Some(ref container) = self.oauth_service { + let oauth_service = OAuthService::new(container.clone()); + oauth_service.handle_authorize(params, ip_address) + } else { + Err("{\"error\": \"server_error\", \"error_description\": \"No OAuth service available\"}".to_string()) + }; + + match result { Ok(redirect_url) => { let security_headers = self.get_security_headers(); let response = format!( @@ -227,10 +259,16 @@ impl Server { // Extract Authorization header from request let auth_header = self.extract_auth_header(request); - match self - .oauth_server - .handle_token(&form_params, auth_header.as_deref(), ip_address) - { + let result = if let Some(ref oauth_server) = self.oauth_server { + oauth_server.handle_token(&form_params, auth_header.as_deref(), ip_address) + } else if let Some(ref container) = self.oauth_service { + let oauth_service = OAuthService::new(container.clone()); + oauth_service.handle_token(&form_params, auth_header.as_deref(), ip_address) + } else { + Err("{\"error\": \"server_error\", \"error_description\": \"No OAuth service available\"}".to_string()) + }; + + match result { Ok(token_response) => { self.send_json_response(stream, 200, "OK", &token_response); } @@ -245,10 +283,16 @@ impl Server { let form_params = self.parse_form_data(&body); let auth_header = self.extract_auth_header(request); - match self - .oauth_server - .handle_token_introspection(&form_params, auth_header.as_deref()) - { + let result = if let Some(ref oauth_server) = self.oauth_server { + oauth_server.handle_token_introspection(&form_params, auth_header.as_deref()) + } else if let Some(ref container) = self.oauth_service { + let oauth_service = OAuthService::new(container.clone()); + oauth_service.handle_token_introspection(&form_params, auth_header.as_deref()) + } else { + Err("{\"error\": \"server_error\", \"error_description\": \"No OAuth service available\"}".to_string()) + }; + + match result { Ok(introspection_response) => { self.send_json_response(stream, 200, "OK", &introspection_response); } @@ -263,10 +307,16 @@ impl Server { let form_params = self.parse_form_data(&body); let auth_header = self.extract_auth_header(request); - match self - .oauth_server - .handle_token_revocation(&form_params, auth_header.as_deref()) - { + let result = if let Some(ref oauth_server) = self.oauth_server { + oauth_server.handle_token_revocation(&form_params, auth_header.as_deref()) + } else if let Some(ref container) = self.oauth_service { + let oauth_service = OAuthService::new(container.clone()); + oauth_service.handle_token_revocation(&form_params, auth_header.as_deref()) + } else { + Err("{\"error\": \"server_error\", \"error_description\": \"No OAuth service available\"}".to_string()) + }; + + match result { Ok(_) => { self.send_empty_response(stream, 200, "OK"); } @@ -1,14 +1,17 @@ pub mod clients; pub mod config; +pub mod container; pub mod database; pub mod http; pub mod keys; pub mod migrations; pub mod oauth; +pub mod repositories; +pub mod services; pub use clients::ClientManager; pub use config::Config; pub use database::Database; pub use http::Server; pub use migrations::MigrationRunner; -pub use oauth::OAuthServer; +pub use oauth::{OAuthServer, OAuthService}; diff --git a/src/main.rs b/src/main.rs index ac47a5e..4873a1d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,22 +1,36 @@ +use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; -use sts::Config; +use sts::container::ServiceContainer; use sts::http::Server; +use sts::{Config, Database}; fn main() { let config = Config::from_env(); - let server = Server::new(config.clone()).expect("Failed to create server"); + + // Initialize database + let database = Database::new(&config.database_path).expect("Failed to initialize database"); + let database = Arc::new(Mutex::new(database)); + + // Initialize service container with dependency injection + let container = ServiceContainer::new(config.clone(), database.clone()) + .expect("Failed to create service container"); + let container = Arc::new(container); + + let server = Server::new_with_container(config.clone(), container.clone()) + .expect("Failed to create server"); // Start cleanup task in background + let cleanup_container = container.clone(); let cleanup_config = config.clone(); thread::spawn(move || { loop { thread::sleep(Duration::from_secs( cleanup_config.cleanup_interval_hours as u64 * 3600, )); - // Note: In the current implementation, we don't have direct access to the OAuth server - // from here to call cleanup_expired_data(). In a production implementation, - // you'd want to structure this differently or use a background job queue. + if let Err(e) = cleanup_container.cleanup_expired_data() { + eprintln!("Cleanup task failed: {}", e); + } } }); @@ -139,11 +153,16 @@ mod tests { // Step 1: Authorization request let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); auth_params.insert("response_type".to_string(), "code".to_string()); auth_params.insert("state".to_string(), "test_state".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); // Extract the authorization code from redirect URL let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); @@ -160,11 +179,13 @@ mod tests { token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "test_secret".to_string()); - let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed"); + let token_result = oauth_server + .handle_token(&token_params, None, Some("127.0.0.1".to_string())) + .expect("Token request failed"); // Parse token response - let token_response: serde_json::Value = serde_json::from_str(&token_result) - .expect("Invalid token response JSON"); + let token_response: serde_json::Value = + serde_json::from_str(&token_result).expect("Invalid token response JSON"); assert_eq!(token_response["token_type"], "Bearer"); assert_eq!(token_response["expires_in"], 3600); @@ -173,7 +194,8 @@ mod tests { let access_token = token_response["access_token"].as_str().unwrap(); // Step 3: Verify the JWT token has RSA signature and key ID - let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); + let header = + jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256); assert!(header.kid.is_some()); assert!(!header.kid.as_ref().unwrap().is_empty()); @@ -187,10 +209,15 @@ mod tests { // Generate a token let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); auth_params.insert("response_type".to_string(), "code".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -204,9 +231,11 @@ mod tests { token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "test_secret".to_string()); - let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed"); - let token_response: serde_json::Value = serde_json::from_str(&token_result) - .expect("Invalid token response JSON"); + let token_result = oauth_server + .handle_token(&token_params, None, Some("127.0.0.1".to_string())) + .expect("Token request failed"); + let token_response: serde_json::Value = + serde_json::from_str(&token_result).expect("Invalid token response JSON"); let access_token = token_response["access_token"].as_str().unwrap(); // Get the JWKS @@ -214,7 +243,8 @@ mod tests { let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON"); // Decode the token header to get the key ID - let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); + let header = + jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); let kid = header.kid.as_ref().expect("No key ID in token"); // Find the matching key in JWKS @@ -237,11 +267,16 @@ mod tests { // Generate a token through the full flow let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); auth_params.insert("response_type".to_string(), "code".to_string()); auth_params.insert("scope".to_string(), "openid profile".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -255,16 +290,18 @@ mod tests { token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "test_secret".to_string()); - let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed"); - let token_response: serde_json::Value = serde_json::from_str(&token_result) - .expect("Invalid token response JSON"); + let token_result = oauth_server + .handle_token(&token_params, None, Some("127.0.0.1".to_string())) + .expect("Token request failed"); + let token_response: serde_json::Value = + serde_json::from_str(&token_result).expect("Invalid token response JSON"); let access_token = token_response["access_token"].as_str().unwrap(); // Decode the token without verification to check claims let _token_data = jsonwebtoken::decode::<serde_json::Value>( access_token, &jsonwebtoken::DecodingKey::from_secret(b"dummy"), // We're not validating, just parsing - &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256) + &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256), ); // Since we can't validate with a dummy key, we'll just verify the structure @@ -275,7 +312,8 @@ mod tests { let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD .decode(parts[1]) .expect("Failed to decode payload"); - let claims: serde_json::Value = serde_json::from_slice(&payload).expect("Invalid claims JSON"); + let claims: serde_json::Value = + serde_json::from_slice(&payload).expect("Invalid claims JSON"); assert!(claims["sub"].is_string()); assert!(claims["iss"].is_string()); @@ -293,7 +331,10 @@ mod tests { let mut params = HashMap::new(); params.insert("client_id".to_string(), "invalid_client".to_string()); - params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); params.insert("response_type".to_string(), "code".to_string()); let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); @@ -308,7 +349,10 @@ mod tests { let mut params = HashMap::new(); params.insert("client_id".to_string(), "test_client".to_string()); - params.insert("redirect_uri".to_string(), "https://evil.com/callback".to_string()); + params.insert( + "redirect_uri".to_string(), + "https://evil.com/callback".to_string(), + ); params.insert("response_type".to_string(), "code".to_string()); let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); @@ -323,7 +367,10 @@ mod tests { let mut params = HashMap::new(); params.insert("client_id".to_string(), "test_client".to_string()); - params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); params.insert("response_type".to_string(), "code".to_string()); params.insert("scope".to_string(), "invalid_scope".to_string()); @@ -340,10 +387,15 @@ mod tests { // First get an authorization code let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); auth_params.insert("response_type".to_string(), "code".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -371,10 +423,15 @@ mod tests { // First get an authorization code let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); auth_params.insert("response_type".to_string(), "code".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -402,10 +459,15 @@ mod tests { // First get an authorization code let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); auth_params.insert("response_type".to_string(), "code".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -422,7 +484,11 @@ mod tests { // test_client:test_secret encoded in base64 let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; - let result = oauth_server.handle_token(&token_params, Some(auth_header), Some("127.0.0.1".to_string())); + let result = oauth_server.handle_token( + &token_params, + Some(auth_header), + Some("127.0.0.1".to_string()), + ); assert!(result.is_ok()); } } diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs index 4b18bb3..b2d46fa 100644 --- a/src/oauth/mod.rs +++ b/src/oauth/mod.rs @@ -1,9 +1,11 @@ pub mod pkce; pub mod server; +pub mod service; pub mod types; pub use pkce::{ - generate_code_challenge, generate_code_verifier, verify_code_challenge, CodeChallengeMethod, + CodeChallengeMethod, generate_code_challenge, generate_code_verifier, verify_code_challenge, }; pub use server::OAuthServer; +pub use service::OAuthService; pub use types::{AuthCode, Claims, ErrorResponse, TokenResponse}; diff --git a/src/oauth/pkce.rs b/src/oauth/pkce.rs index 406d364..0dfc1f8 100644 --- a/src/oauth/pkce.rs +++ b/src/oauth/pkce.rs @@ -1,5 +1,5 @@ -use anyhow::{anyhow, Result}; -use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; +use anyhow::{Result, anyhow}; +use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; use sha2::{Digest, Sha256}; #[derive(Debug, Clone, PartialEq)] @@ -124,12 +124,14 @@ mod tests { assert!(verify_code_challenge("short", "challenge", &CodeChallengeMethod::Plain).is_err()); // Invalid characters - assert!(verify_code_challenge( - "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", - "challenge", - &CodeChallengeMethod::Plain - ) - .is_err()); + assert!( + verify_code_challenge( + "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", + "challenge", + &CodeChallengeMethod::Plain + ) + .is_err() + ); } #[test] diff --git a/src/oauth/server.rs b/src/oauth/server.rs index 7fd8b9c..37c3cbc 100644 --- a/src/oauth/server.rs +++ b/src/oauth/server.rs @@ -1,12 +1,12 @@ -use crate::clients::{parse_basic_auth, ClientManager}; +use crate::clients::{ClientManager, parse_basic_auth}; use crate::config::Config; use crate::database::{Database, DbAccessToken, DbAuditLog, DbAuthCode}; use crate::keys::KeyManager; -use crate::oauth::pkce::{verify_code_challenge, CodeChallengeMethod}; +use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge}; use crate::oauth::types::{Claims, ErrorResponse, TokenIntrospectionResponse, TokenResponse}; -use anyhow::{anyhow, Result}; +use anyhow::{Result, anyhow}; use chrono::{Duration, Utc}; -use jsonwebtoken::{encode, Algorithm, Header}; +use jsonwebtoken::{Algorithm, Header, encode}; use sha2::{Digest, Sha256}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; diff --git a/src/oauth/service.rs b/src/oauth/service.rs new file mode 100644 index 0000000..1b4eb49 --- /dev/null +++ b/src/oauth/service.rs @@ -0,0 +1,566 @@ +use crate::container::ServiceContainer; +use crate::database::{DbAccessToken, DbAuthCode}; +use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge}; +use crate::oauth::types::{ErrorResponse, TokenIntrospectionResponse, TokenResponse}; +use anyhow::Result; +use chrono::{Duration, Utc}; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use std::sync::Arc; +use url::Url; +use uuid::Uuid; + +/// Refactored OAuth service using dependency injection +pub struct OAuthService { + container: Arc<ServiceContainer>, +} + +impl OAuthService { + pub fn new(container: Arc<ServiceContainer>) -> Self { + Self { container } + } + + pub fn get_jwks(&self) -> String { + self.container.get_jwks() + } + + pub fn handle_authorize( + &self, + params: &HashMap<String, String>, + ip_address: Option<String>, + ) -> Result<String, String> { + let client_id = params + .get("client_id") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + + let redirect_uri = params + .get("redirect_uri") + .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?; + + let response_type = params + .get("response_type") + .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?; + + // Rate limiting check + if let Err(e) = self + .container + .rate_limiter + .check_rate_limit(&format!("client:{}", client_id), "/authorize") + { + let _ = self.container.audit_logger.log_event( + "authorize_rate_limited", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(&e.to_string()), + ); + return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); + } + + // Validate client exists + let client = match self.container.client_repository.get_client(client_id) { + Ok(Some(client)) => client, + Ok(None) => { + let _ = self.container.audit_logger.log_event( + "authorize_invalid_client", + Some(client_id), + None, + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_client", "Invalid client_id")); + } + Err(_) => { + return Err(self.error_response("server_error", "Internal server error")); + } + }; + + // Validate redirect URI + let redirect_uris: Vec<String> = + serde_json::from_str(&client.redirect_uris).unwrap_or_else(|_| vec![]); + + if !redirect_uris.contains(redirect_uri) { + let _ = self.container.audit_logger.log_event( + "authorize_invalid_redirect_uri", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(redirect_uri), + ); + return Err(self.error_response("invalid_request", "Invalid redirect_uri")); + } + + // Validate requested scopes + let scope = params.get("scope").cloned(); + if let Some(ref scope_str) = scope { + let client_scopes: Vec<&str> = client.scopes.split_whitespace().collect(); + let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect(); + + for requested_scope in &requested_scopes { + if !client_scopes.contains(requested_scope) { + let _ = self.container.audit_logger.log_event( + "authorize_invalid_scope", + Some(client_id), + None, + ip_address.as_deref(), + false, + scope.as_deref(), + ); + return Err(self.error_response("invalid_scope", "Invalid scope")); + } + } + } + + if response_type != "code" { + let _ = self.container.audit_logger.log_event( + "authorize_unsupported_response_type", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(response_type), + ); + return Err(self.error_response( + "unsupported_response_type", + "Only code response type supported", + )); + } + + // PKCE validation (RFC 7636) + let code_challenge = params.get("code_challenge"); + let code_challenge_method = params + .get("code_challenge_method") + .map(|method| CodeChallengeMethod::from_str(method)) + .transpose() + .map_err(|_| self.error_response("invalid_request", "Invalid code_challenge_method"))?; + + // For public clients, PKCE is required + if client.client_id.starts_with("public_") && code_challenge.is_none() { + let _ = self.container.audit_logger.log_event( + "authorize_missing_pkce", + Some(client_id), + None, + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_request", "PKCE required for public clients")); + } + + let code = Uuid::new_v4().to_string(); + let expires_at = Utc::now() + Duration::minutes(10); // 10 minute expiration + + let db_auth_code = DbAuthCode { + id: 0, // Will be set by database + code: code.clone(), + client_id: client_id.clone(), + user_id: "test_user".to_string(), // In a real implementation, this would come from authentication + redirect_uri: redirect_uri.clone(), + scope: scope.clone(), + expires_at, + created_at: Utc::now(), + is_used: false, + code_challenge: code_challenge.cloned(), + code_challenge_method: code_challenge_method + .as_ref() + .map(|m| m.as_str().to_string()), + }; + + // Save to database + if let Err(_) = self + .container + .auth_code_repository + .create_auth_code(&db_auth_code) + { + return Err(self.error_response("server_error", "Failed to create authorization code")); + } + + let mut redirect_url = Url::parse(redirect_uri) + .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?; + + redirect_url.query_pairs_mut().append_pair("code", &code); + + if let Some(state) = params.get("state") { + redirect_url.query_pairs_mut().append_pair("state", state); + } + + let _ = self.container.audit_logger.log_event( + "authorize_success", + Some(client_id), + Some("test_user"), + ip_address.as_deref(), + true, + None, + ); + + Ok(redirect_url.to_string()) + } + + pub fn handle_token( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ip_address: Option<String>, + ) -> Result<String, String> { + let grant_type = params + .get("grant_type") + .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; + + match grant_type.as_str() { + "authorization_code" => { + self.handle_authorization_code_grant(params, auth_header, ip_address) + } + "refresh_token" => self.handle_refresh_token_grant(params, auth_header, ip_address), + _ => { + let _ = self.container.audit_logger.log_event( + "token_unsupported_grant_type", + None, + None, + ip_address.as_deref(), + false, + Some(grant_type), + ); + Err(self.error_response("unsupported_grant_type", "Unsupported grant type")) + } + } + } + + fn handle_authorization_code_grant( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ip_address: Option<String>, + ) -> Result<String, String> { + let code = params + .get("code") + .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?; + + // Client authentication using injected service + let (client_id, _client_secret) = self + .container + .client_authenticator + .authenticate(params, auth_header) + .map_err(|e| self.error_response("invalid_client", &e))?; + + // Rate limiting check + if let Err(e) = self + .container + .rate_limiter + .check_rate_limit(&format!("client:{}", client_id), "/token") + { + let _ = self.container.audit_logger.log_event( + "token_rate_limited", + Some(&client_id), + None, + ip_address.as_deref(), + false, + Some(&e.to_string()), + ); + return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); + } + + // Get and validate authorization code + let auth_code = match self.container.auth_code_repository.get_auth_code(code) { + Ok(Some(auth_code)) => auth_code, + Ok(None) => { + let _ = self.container.audit_logger.log_event( + "token_invalid_code", + Some(&client_id), + None, + ip_address.as_deref(), + false, + Some(code), + ); + return Err( + self.error_response("invalid_grant", "Invalid or expired authorization code") + ); + } + Err(_) => { + return Err(self.error_response("server_error", "Internal server error")); + } + }; + + // Validate code hasn't been used and hasn't expired + if auth_code.is_used { + let _ = self.container.audit_logger.log_event( + "token_code_reuse", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + Some(code), + ); + return Err(self.error_response("invalid_grant", "Authorization code already used")); + } + + if Utc::now() > auth_code.expires_at { + let _ = self.container.audit_logger.log_event( + "token_code_expired", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + Some(code), + ); + return Err(self.error_response("invalid_grant", "Authorization code expired")); + } + + if auth_code.client_id != client_id { + let _ = self.container.audit_logger.log_event( + "token_client_mismatch", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Client ID mismatch")); + } + + // PKCE validation if code challenge was provided + if let Some(code_challenge) = &auth_code.code_challenge { + let code_verifier = params.get("code_verifier").ok_or_else(|| { + self.error_response("invalid_request", "Missing code_verifier for PKCE") + })?; + + let challenge_method = auth_code + .code_challenge_method + .as_ref() + .and_then(|method| CodeChallengeMethod::from_str(method).ok()) + .unwrap_or(CodeChallengeMethod::Plain); + + if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method) + { + let _ = self.container.audit_logger.log_event( + "token_pkce_verification_failed", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "PKCE verification failed")); + } + } + + // Mark code as used + if let Err(_) = self + .container + .auth_code_repository + .mark_auth_code_used(code) + { + return Err(self.error_response("server_error", "Failed to mark code as used")); + } + + // Generate tokens using injected service + let token_id = Uuid::new_v4().to_string(); + let access_token = self.container.token_generator.generate_access_token( + &auth_code.user_id, + &client_id, + &auth_code.scope, + &token_id, + )?; + let refresh_token = self.container.token_generator.generate_refresh_token( + &client_id, + &auth_code.user_id, + &auth_code.scope, + )?; + + // Store token in database for revocation/introspection + let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); + let db_access_token = DbAccessToken { + id: 0, + token_id: token_id.clone(), + client_id: client_id.clone(), + user_id: auth_code.user_id.clone(), + scope: auth_code.scope.clone(), + expires_at: Utc::now() + Duration::hours(1), + created_at: Utc::now(), + is_revoked: false, + token_hash, + }; + + if let Err(_) = self + .container + .token_repository + .create_access_token(&db_access_token) + { + return Err(self.error_response("server_error", "Failed to store access token")); + } + + let token_response = TokenResponse { + access_token, + token_type: "Bearer".to_string(), + expires_in: 3600, + refresh_token: Some(refresh_token), + scope: auth_code.scope, + }; + + let _ = self.container.audit_logger.log_event( + "token_success", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + true, + None, + ); + + serde_json::to_string(&token_response) + .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) + } + + fn handle_refresh_token_grant( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ip_address: Option<String>, + ) -> Result<String, String> { + let _refresh_token = params + .get("refresh_token") + .ok_or_else(|| self.error_response("invalid_request", "Missing refresh_token"))?; + + // Client authentication using injected service + let (client_id, _client_secret) = self + .container + .client_authenticator + .authenticate(params, auth_header) + .map_err(|e| self.error_response("invalid_client", &e))?; + + // Validate refresh token (implementation would verify token and get user info) + // For now, return a simple refresh token response + let new_token_id = Uuid::new_v4().to_string(); + let access_token = self.container.token_generator.generate_access_token( + "test_user", + &client_id, + &None, + &new_token_id, + )?; + let new_refresh_token = self.container.token_generator.generate_refresh_token( + &client_id, + "test_user", + &None, + )?; + + let token_response = TokenResponse { + access_token, + token_type: "Bearer".to_string(), + expires_in: 3600, + refresh_token: Some(new_refresh_token), + scope: None, + }; + + let _ = self.container.audit_logger.log_event( + "refresh_success", + Some(&client_id), + Some("test_user"), + ip_address.as_deref(), + true, + None, + ); + + serde_json::to_string(&token_response) + .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) + } + + pub fn handle_token_introspection( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ) -> Result<String, String> { + let token = params + .get("token") + .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?; + + // Authenticate the client making the introspection request using injected service + let (_client_id, _client_secret) = self + .container + .client_authenticator + .authenticate(params, auth_header) + .map_err(|e| self.error_response("invalid_client", &e))?; + + // Look up token in database using repository + let token_hash = format!("{:x}", Sha256::digest(token.as_bytes())); + let db_token = self + .container + .token_repository + .get_access_token(&token_hash) + .ok() + .flatten(); + + let response = if let Some(db_token) = db_token { + if !db_token.is_revoked && Utc::now() < db_token.expires_at { + TokenIntrospectionResponse { + active: true, + client_id: Some(db_token.client_id.clone()), + username: Some(db_token.user_id.clone()), + scope: db_token.scope.clone(), + exp: Some(db_token.expires_at.timestamp() as u64), + iat: Some(db_token.created_at.timestamp() as u64), + sub: Some(db_token.user_id), + aud: Some(db_token.client_id), + iss: Some(self.container.config.issuer_url.clone()), + jti: Some(db_token.token_id), + } + } else { + TokenIntrospectionResponse::inactive() + } + } else { + TokenIntrospectionResponse::inactive() + }; + + serde_json::to_string(&response) + .map_err(|_| self.error_response("server_error", "Failed to serialize response")) + } + + pub fn handle_token_revocation( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ) -> Result<(), String> { + let token = params + .get("token") + .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?; + + // Authenticate the client making the revocation request using injected service + let (client_id, _client_secret) = self + .container + .client_authenticator + .authenticate(params, auth_header) + .map_err(|e| self.error_response("invalid_client", &e))?; + + // Revoke token in database using repository + let token_hash = format!("{:x}", Sha256::digest(token.as_bytes())); + let _ = self + .container + .token_repository + .revoke_access_token(&token_hash); // Ignore errors as per RFC 7009 + + let _ = self.container.audit_logger.log_event( + "token_revoked", + Some(&client_id), + None, + None, + true, + None, + ); + + Ok(()) + } + + fn error_response(&self, error: &str, description: &str) -> String { + let error_resp = ErrorResponse { + error: error.to_string(), + error_description: Some(description.to_string()), + error_uri: None, + }; + serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string()) + } + + /// Cleanup expired data using repositories + pub fn cleanup_expired_data(&self) -> Result<()> { + self.container.cleanup_expired_data() + } +} diff --git a/src/oauth/types.rs b/src/oauth/types.rs index 4f2c363..3d1c581 100644 --- a/src/oauth/types.rs +++ b/src/oauth/types.rs @@ -76,6 +76,23 @@ pub struct TokenIntrospectionResponse { pub jti: Option<String>, } +impl TokenIntrospectionResponse { + pub fn inactive() -> Self { + Self { + active: false, + client_id: None, + username: None, + scope: None, + exp: None, + iat: None, + sub: None, + aud: None, + iss: None, + jti: None, + } + } +} + #[derive(Debug, Serialize, Deserialize)] pub struct TokenRevocationRequest { pub token: String, diff --git a/src/repositories/mod.rs b/src/repositories/mod.rs new file mode 100644 index 0000000..1685fe0 --- /dev/null +++ b/src/repositories/mod.rs @@ -0,0 +1,52 @@ +use crate::database::{DbAccessToken, DbAuditLog, DbAuthCode, DbOAuthClient}; +use anyhow::Result; + +pub mod sqlite; + +pub use sqlite::{ + SqliteAuditRepository, SqliteAuthCodeRepository, SqliteClientRepository, SqliteRateRepository, + SqliteTokenRepository, +}; + +/// Repository trait for OAuth client operations +pub trait ClientRepository: Send + Sync { + fn get_client(&self, client_id: &str) -> Result<Option<DbOAuthClient>>; + fn create_client(&self, client: &DbOAuthClient) -> Result<()>; + fn update_client(&self, client: &DbOAuthClient) -> Result<()>; + fn delete_client(&self, client_id: &str) -> Result<()>; + fn list_clients(&self) -> Result<Vec<DbOAuthClient>>; +} + +/// Repository trait for authorization code operations +pub trait AuthCodeRepository: Send + Sync { + fn create_auth_code(&self, code: &DbAuthCode) -> Result<()>; + fn get_auth_code(&self, code: &str) -> Result<Option<DbAuthCode>>; + fn mark_auth_code_used(&self, code: &str) -> Result<()>; + fn cleanup_expired_codes(&self) -> Result<()>; +} + +/// Repository trait for access token operations +pub trait TokenRepository: Send + Sync { + fn create_access_token(&self, token: &DbAccessToken) -> Result<()>; + fn get_access_token(&self, token_hash: &str) -> Result<Option<DbAccessToken>>; + fn revoke_access_token(&self, token_hash: &str) -> Result<()>; + fn cleanup_expired_tokens(&self) -> Result<()>; +} + +/// Repository trait for audit log operations +pub trait AuditRepository: Send + Sync { + fn create_audit_log(&self, log: &DbAuditLog) -> Result<()>; + fn get_audit_logs(&self, limit: Option<i32>) -> Result<Vec<DbAuditLog>>; + fn cleanup_old_audit_logs(&self, days: i32) -> Result<()>; +} + +/// Repository trait for rate limiting operations +pub trait RateRepository: Send + Sync { + fn increment_rate_limit( + &self, + identifier: &str, + endpoint: &str, + window_size: i32, + ) -> Result<i32>; + fn cleanup_old_rate_limits(&self) -> Result<()>; +} diff --git a/src/repositories/sqlite.rs b/src/repositories/sqlite.rs new file mode 100644 index 0000000..79e6025 --- /dev/null +++ b/src/repositories/sqlite.rs @@ -0,0 +1,164 @@ +use super::*; +use crate::database::{Database, DbAccessToken, DbAuditLog, DbAuthCode, DbOAuthClient}; +use anyhow::Result; +use std::sync::{Arc, Mutex}; + +/// SQLite implementation of ClientRepository +pub struct SqliteClientRepository { + database: Arc<Mutex<Database>>, +} + +impl SqliteClientRepository { + pub fn new(database: Arc<Mutex<Database>>) -> Self { + Self { database } + } +} + +impl ClientRepository for SqliteClientRepository { + fn get_client(&self, client_id: &str) -> Result<Option<DbOAuthClient>> { + let db = self.database.lock().unwrap(); + db.get_oauth_client(client_id) + } + + fn create_client(&self, client: &DbOAuthClient) -> Result<()> { + let db = self.database.lock().unwrap(); + db.create_oauth_client(client).map(|_| ()) + } + + fn update_client(&self, client: &DbOAuthClient) -> Result<()> { + let db = self.database.lock().unwrap(); + db.update_oauth_client(client) + } + + fn delete_client(&self, client_id: &str) -> Result<()> { + let db = self.database.lock().unwrap(); + db.delete_oauth_client(client_id) + } + + fn list_clients(&self) -> Result<Vec<DbOAuthClient>> { + let db = self.database.lock().unwrap(); + db.list_oauth_clients() + } +} + +/// SQLite implementation of AuthCodeRepository +pub struct SqliteAuthCodeRepository { + database: Arc<Mutex<Database>>, +} + +impl SqliteAuthCodeRepository { + pub fn new(database: Arc<Mutex<Database>>) -> Self { + Self { database } + } +} + +impl AuthCodeRepository for SqliteAuthCodeRepository { + fn create_auth_code(&self, code: &DbAuthCode) -> Result<()> { + let db = self.database.lock().unwrap(); + db.create_auth_code(code).map(|_| ()) + } + + fn get_auth_code(&self, code: &str) -> Result<Option<DbAuthCode>> { + let db = self.database.lock().unwrap(); + db.get_auth_code(code) + } + + fn mark_auth_code_used(&self, code: &str) -> Result<()> { + let db = self.database.lock().unwrap(); + db.mark_auth_code_used(code) + } + + fn cleanup_expired_codes(&self) -> Result<()> { + let db = self.database.lock().unwrap(); + db.cleanup_expired_codes().map(|_| ()) + } +} + +/// SQLite implementation of TokenRepository +pub struct SqliteTokenRepository { + database: Arc<Mutex<Database>>, +} + +impl SqliteTokenRepository { + pub fn new(database: Arc<Mutex<Database>>) -> Self { + Self { database } + } +} + +impl TokenRepository for SqliteTokenRepository { + fn create_access_token(&self, token: &DbAccessToken) -> Result<()> { + let db = self.database.lock().unwrap(); + db.create_access_token(token).map(|_| ()) + } + + fn get_access_token(&self, token_hash: &str) -> Result<Option<DbAccessToken>> { + let db = self.database.lock().unwrap(); + db.get_access_token(token_hash) + } + + fn revoke_access_token(&self, token_hash: &str) -> Result<()> { + let db = self.database.lock().unwrap(); + db.revoke_access_token(token_hash) + } + + fn cleanup_expired_tokens(&self) -> Result<()> { + let db = self.database.lock().unwrap(); + db.cleanup_expired_tokens().map(|_| ()) + } +} + +/// SQLite implementation of AuditRepository +pub struct SqliteAuditRepository { + database: Arc<Mutex<Database>>, +} + +impl SqliteAuditRepository { + pub fn new(database: Arc<Mutex<Database>>) -> Self { + Self { database } + } +} + +impl AuditRepository for SqliteAuditRepository { + fn create_audit_log(&self, log: &DbAuditLog) -> Result<()> { + let db = self.database.lock().unwrap(); + db.create_audit_log(log).map(|_| ()) + } + + fn get_audit_logs(&self, limit: Option<i32>) -> Result<Vec<DbAuditLog>> { + let db = self.database.lock().unwrap(); + db.get_audit_logs(limit.unwrap_or(100)) + } + + fn cleanup_old_audit_logs(&self, days: i32) -> Result<()> { + let db = self.database.lock().unwrap(); + db.cleanup_old_audit_logs(days).map(|_| ()) + } +} + +/// SQLite implementation of RateRepository +pub struct SqliteRateRepository { + database: Arc<Mutex<Database>>, +} + +impl SqliteRateRepository { + pub fn new(database: Arc<Mutex<Database>>) -> Self { + Self { database } + } +} + +impl RateRepository for SqliteRateRepository { + fn increment_rate_limit( + &self, + identifier: &str, + endpoint: &str, + window_size: i32, + ) -> Result<i32> { + let db = self.database.lock().unwrap(); + db.increment_rate_limit(identifier, endpoint, window_size) + } + + fn cleanup_old_rate_limits(&self) -> Result<()> { + let db = self.database.lock().unwrap(); + db.cleanup_old_rate_limits() + } +} diff --git a/src/services/implementations.rs b/src/services/implementations.rs new file mode 100644 index 0000000..ff03165 --- /dev/null +++ b/src/services/implementations.rs @@ -0,0 +1,217 @@ +use super::*; +use crate::clients::parse_basic_auth; +use crate::config::Config; +use crate::database::DbAuditLog; +use crate::keys::KeyManager; +use crate::oauth::types::Claims; +use crate::repositories::{AuditRepository, ClientRepository, RateRepository}; +use anyhow::{Result, anyhow}; +use chrono::Utc; +use jsonwebtoken::{Algorithm, Header, encode}; +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{SystemTime, UNIX_EPOCH}; +use uuid::Uuid; + +/// Default implementation of ClientAuthenticator +pub struct DefaultClientAuthenticator { + client_repository: Arc<dyn ClientRepository>, +} + +impl DefaultClientAuthenticator { + pub fn new(client_repository: Arc<dyn ClientRepository>) -> Self { + Self { client_repository } + } + + fn authenticate_client(&self, client_id: &str, client_secret: &str) -> bool { + match self.client_repository.get_client(client_id) { + Ok(Some(client)) => { + // Use constant-time comparison to prevent timing attacks + use subtle::ConstantTimeEq; + let expected_hash = client.client_secret_hash.as_bytes(); + let provided_hash = self.hash_client_secret(client_secret); + expected_hash.ct_eq(provided_hash.as_bytes()).into() + } + _ => false, + } + } + + fn hash_client_secret(&self, secret: &str) -> String { + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(secret.as_bytes()); + format!("{:x}", hasher.finalize()) + } +} + +impl ClientAuthenticator for DefaultClientAuthenticator { + fn authenticate( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ) -> Result<(String, String), String> { + if let Some(auth_header) = auth_header { + // HTTP Basic Authentication (preferred method) + parse_basic_auth(auth_header).ok_or_else(|| "Invalid Authorization header".to_string()) + } else { + // Form-based authentication (fallback) + let client_id = params + .get("client_id") + .ok_or_else(|| "Missing client_id".to_string())?; + let client_secret = params + .get("client_secret") + .ok_or_else(|| "Missing client_secret".to_string())?; + + if !self.authenticate_client(client_id, client_secret) { + return Err("Invalid client credentials".to_string()); + } + + Ok((client_id.clone(), client_secret.clone())) + } + } +} + +/// Default implementation of RateLimiter +pub struct DefaultRateLimiter { + rate_repository: Arc<dyn RateRepository>, + config: Config, +} + +impl DefaultRateLimiter { + pub fn new(rate_repository: Arc<dyn RateRepository>, config: Config) -> Self { + Self { + rate_repository, + config, + } + } +} + +impl RateLimiter for DefaultRateLimiter { + fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> { + let count = self + .rate_repository + .increment_rate_limit(identifier, endpoint, 1)?; + + if count > self.config.rate_limit_requests_per_minute as i32 { + return Err(anyhow!("Rate limit exceeded")); + } + + Ok(()) + } +} + +/// Default implementation of AuditLogger +pub struct DefaultAuditLogger { + audit_repository: Arc<dyn AuditRepository>, + config: Config, +} + +impl DefaultAuditLogger { + pub fn new(audit_repository: Arc<dyn AuditRepository>, config: Config) -> Self { + Self { + audit_repository, + config, + } + } +} + +impl AuditLogger for DefaultAuditLogger { + fn log_event( + &self, + event_type: &str, + client_id: Option<&str>, + user_id: Option<&str>, + ip_address: Option<&str>, + success: bool, + details: Option<&str>, + ) -> Result<()> { + if !self.config.enable_audit_logging { + return Ok(()); + } + + let log = DbAuditLog { + id: 0, + event_type: event_type.to_string(), + client_id: client_id.map(|s| s.to_string()), + user_id: user_id.map(|s| s.to_string()), + ip_address: ip_address.map(|s| s.to_string()), + user_agent: None, // Could be passed in from HTTP layer + details: details.map(|s| s.to_string()), + created_at: Utc::now(), + success, + }; + + self.audit_repository.create_audit_log(&log)?; + Ok(()) + } +} + +/// Default implementation of TokenGenerator +pub struct DefaultTokenGenerator { + key_manager: Arc<Mutex<KeyManager>>, + config: Config, +} + +impl DefaultTokenGenerator { + pub fn new(key_manager: Arc<Mutex<KeyManager>>, config: Config) -> Self { + Self { + key_manager, + config, + } + } +} + +impl TokenGenerator for DefaultTokenGenerator { + fn generate_access_token( + &self, + user_id: &str, + client_id: &str, + scope: &Option<String>, + token_id: &str, + ) -> Result<String, String> { + let mut key_manager = self.key_manager.lock().unwrap(); + + // Check if we need to rotate keys + if key_manager.should_rotate() { + if let Err(_) = key_manager.rotate_keys() { + return Err("Key rotation failed".to_string()); + } + } + + let current_key = key_manager + .get_current_key() + .ok_or_else(|| "No signing key available".to_string())?; + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let claims = Claims { + sub: user_id.to_string(), + iss: self.config.issuer_url.clone(), + aud: client_id.to_string(), + exp: now + 3600, + iat: now, + scope: scope.clone(), + jti: Some(token_id.to_string()), + }; + + let mut header = Header::new(Algorithm::RS256); + header.kid = Some(current_key.kid.clone()); + + encode(&header, &claims, ¤t_key.encoding_key) + .map_err(|_| "Failed to generate token".to_string()) + } + + fn generate_refresh_token( + &self, + _client_id: &str, + _user_id: &str, + _scope: &Option<String>, + ) -> Result<String, String> { + // For now, return a simple UUID-based refresh token + // In production, this should be a proper JWT or encrypted token + Ok(Uuid::new_v4().to_string()) + } +} diff --git a/src/services/mod.rs b/src/services/mod.rs new file mode 100644 index 0000000..26d74e3 --- /dev/null +++ b/src/services/mod.rs @@ -0,0 +1,49 @@ +use anyhow::Result; +use std::collections::HashMap; + +/// Service trait for client authentication +pub trait ClientAuthenticator: Send + Sync { + fn authenticate( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ) -> Result<(String, String), String>; // Returns (client_id, client_secret) +} + +/// Service trait for rate limiting +pub trait RateLimiter: Send + Sync { + fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()>; +} + +/// Service trait for audit logging +pub trait AuditLogger: Send + Sync { + fn log_event( + &self, + event_type: &str, + client_id: Option<&str>, + user_id: Option<&str>, + ip_address: Option<&str>, + success: bool, + details: Option<&str>, + ) -> Result<()>; +} + +/// Service trait for token generation +pub trait TokenGenerator: Send + Sync { + fn generate_access_token( + &self, + user_id: &str, + client_id: &str, + scope: &Option<String>, + token_id: &str, + ) -> Result<String, String>; + + fn generate_refresh_token( + &self, + client_id: &str, + user_id: &str, + scope: &Option<String>, + ) -> Result<String, String>; +} + +pub mod implementations; |
