diff options
Diffstat (limited to 'src/clients.rs')
| -rw-r--r-- | src/clients.rs | 112 |
1 files changed, 69 insertions, 43 deletions
diff --git a/src/clients.rs b/src/clients.rs index bc9aea5..7941d8c 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -1,12 +1,12 @@ +use crate::database::{Database, DbOAuthClient}; +use anyhow::Result; use base64::Engine; +use chrono::Utc; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; use uuid::Uuid; -use chrono::Utc; -use crate::database::{Database, DbOAuthClient}; -use anyhow::Result; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OAuthClient { @@ -37,10 +37,10 @@ impl ClientManager { clients: HashMap::new(), database: database.clone(), }; - + // Load existing clients from database into cache manager.load_clients_from_db()?; - + // Register a default test client for development if it doesn't exist if manager.get_client_from_db("test_client")?.is_none() { let _ = manager.register_client( @@ -48,13 +48,17 @@ impl ClientManager { "test_secret".to_string(), vec!["http://localhost:3000/callback".to_string()], "Test Client".to_string(), - vec!["openid".to_string(), "profile".to_string(), "email".to_string()], + vec![ + "openid".to_string(), + "profile".to_string(), + "email".to_string(), + ], ); // Ignore errors if client already exists } - + Ok(manager) } - + fn load_clients_from_db(&mut self) -> Result<()> { // This is a simplified version - in practice you'd want to load all clients // For now we'll load on-demand @@ -129,40 +133,52 @@ impl ClientManager { if let Some(client) = self.clients.get(client_id) { return Some(client); } - + // If not in cache, try to load from database // For thread safety, we can't mutate self here, so we'll return None // In a real implementation, you'd want a more sophisticated caching strategy None } - + pub fn get_client_from_db(&mut self, client_id: &str) -> Result<Option<OAuthClient>> { // Check cache first if let Some(client) = self.clients.get(client_id) { return Ok(Some(client.clone())); } - + // Load from database let db_client = { let db = self.database.lock().unwrap(); db.get_oauth_client(client_id)? }; - + if let Some(db_client) = db_client { let redirect_uris: Vec<String> = serde_json::from_str(&db_client.redirect_uris)?; - let scopes: Vec<String> = db_client.scopes.split_whitespace().map(|s| s.to_string()).collect(); - + let scopes: Vec<String> = db_client + .scopes + .split_whitespace() + .map(|s| s.to_string()) + .collect(); + let client = OAuthClient { client_id: db_client.client_id.clone(), client_secret_hash: db_client.client_secret_hash, redirect_uris, client_name: db_client.client_name, scopes, - grant_types: db_client.grant_types.split_whitespace().map(|s| s.to_string()).collect(), - response_types: db_client.response_types.split_whitespace().map(|s| s.to_string()).collect(), + grant_types: db_client + .grant_types + .split_whitespace() + .map(|s| s.to_string()) + .collect(), + response_types: db_client + .response_types + .split_whitespace() + .map(|s| s.to_string()) + .collect(), created_at: db_client.created_at.timestamp() as u64, }; - + // Cache it self.clients.insert(db_client.client_id, client.clone()); Ok(Some(client)) @@ -181,7 +197,7 @@ impl ClientManager { return false; } }; - + let provided_hash = Self::hash_secret(client_secret); // Use constant-time comparison to prevent timing attacks self.constant_time_eq(&client.client_secret_hash, &provided_hash) @@ -199,7 +215,9 @@ impl ClientManager { if let Ok(Some(client)) = self.get_client_from_db(client_id) { if let Some(scopes_str) = requested_scopes { let requested: Vec<&str> = scopes_str.split_whitespace().collect(); - requested.iter().all(|scope| client.scopes.contains(&scope.to_string())) + requested + .iter() + .all(|scope| client.scopes.contains(&scope.to_string())) } else { true // No scopes requested is valid } @@ -224,7 +242,7 @@ impl ClientManager { if a.len() != b.len() { return false; } - + let mut result = 0u8; for (byte_a, byte_b) in a.bytes().zip(b.bytes()) { result |= byte_a ^ byte_b; @@ -232,16 +250,24 @@ impl ClientManager { result == 0 } - pub fn generate_client_credentials(&mut self, client_name: String, redirect_uris: Vec<String>) -> Result<ClientCredentials> { + pub fn generate_client_credentials( + &mut self, + client_name: String, + redirect_uris: Vec<String>, + ) -> Result<ClientCredentials> { let client_id = format!("client_{}", Uuid::new_v4().to_string().replace("-", "")); let client_secret = Uuid::new_v4().to_string(); - + self.register_client( client_id.clone(), client_secret.clone(), redirect_uris, client_name, - vec!["openid".to_string(), "profile".to_string(), "email".to_string()], + vec![ + "openid".to_string(), + "profile".to_string(), + "email".to_string(), + ], )?; Ok(ClientCredentials { @@ -253,7 +279,7 @@ impl ClientManager { pub fn list_clients(&self) -> Vec<&OAuthClient> { self.clients.values().collect() } - + pub fn list_all_clients_from_db(&self) -> Result<Vec<DbOAuthClient>> { // This would require a new database method - for now return empty Ok(vec![]) @@ -270,13 +296,13 @@ pub fn parse_basic_auth(auth_header: &str) -> Option<(String, String)> { let decoded = base64::engine::general_purpose::STANDARD .decode(encoded) .ok()?; - + let credentials = String::from_utf8(decoded).ok()?; let mut parts = credentials.splitn(2, ':'); - + let username = parts.next()?.to_string(); let password = parts.next()?.to_string(); - + Some((username, password)) } @@ -288,7 +314,7 @@ mod disabled_tests { #[test] fn test_client_registration() { let mut manager = ClientManager::new(); - + let result = manager.register_client( "new_client".to_string(), "secret123".to_string(), @@ -296,7 +322,7 @@ mod disabled_tests { "Test App".to_string(), vec!["openid".to_string()], ); - + assert!(result.is_ok()); let client = result.unwrap(); assert_eq!(client.client_id, "new_client"); @@ -306,7 +332,7 @@ mod disabled_tests { #[test] fn test_duplicate_client_id() { let mut manager = ClientManager::new(); - + manager.register_client( "duplicate".to_string(), "secret1".to_string(), @@ -314,7 +340,7 @@ mod disabled_tests { "App 1".to_string(), vec!["openid".to_string()], ).unwrap(); - + let result = manager.register_client( "duplicate".to_string(), "secret2".to_string(), @@ -322,14 +348,14 @@ mod disabled_tests { "App 2".to_string(), vec!["openid".to_string()], ); - + assert!(result.is_err()); } #[test] fn test_client_authentication() { let mut manager = ClientManager::new(); - + manager.register_client( "auth_client".to_string(), "correct_secret".to_string(), @@ -337,7 +363,7 @@ mod disabled_tests { "Auth Test".to_string(), vec!["openid".to_string()], ).unwrap(); - + assert!(manager.authenticate_client("auth_client", "correct_secret")); assert!(!manager.authenticate_client("auth_client", "wrong_secret")); assert!(!manager.authenticate_client("nonexistent", "any_secret")); @@ -346,7 +372,7 @@ mod disabled_tests { #[test] fn test_redirect_uri_validation() { let mut manager = ClientManager::new(); - + manager.register_client( "uri_client".to_string(), "secret".to_string(), @@ -354,7 +380,7 @@ mod disabled_tests { "URI Test".to_string(), vec!["openid".to_string()], ).unwrap(); - + assert!(manager.is_redirect_uri_valid("uri_client", "https://app.com/callback")); assert!(manager.is_redirect_uri_valid("uri_client", "http://localhost:3000/callback")); assert!(!manager.is_redirect_uri_valid("uri_client", "https://evil.com/callback")); @@ -364,7 +390,7 @@ mod disabled_tests { #[test] fn test_scope_validation() { let mut manager = ClientManager::new(); - + manager.register_client( "scope_client".to_string(), "secret".to_string(), @@ -372,7 +398,7 @@ mod disabled_tests { "Scope Test".to_string(), vec!["openid".to_string(), "profile".to_string()], ).unwrap(); - + assert!(manager.is_scope_valid("scope_client", &Some("openid".to_string()))); assert!(manager.is_scope_valid("scope_client", &Some("openid profile".to_string()))); assert!(!manager.is_scope_valid("scope_client", &Some("openid profile email".to_string()))); @@ -383,7 +409,7 @@ mod disabled_tests { fn test_basic_auth_parsing() { let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; // test_client:test_secret let result = parse_basic_auth(auth_header); - + assert!(result.is_some()); let (username, password) = result.unwrap(); assert_eq!(username, "test_client"); @@ -393,19 +419,19 @@ mod disabled_tests { #[test] fn test_generate_client_credentials() { let mut manager = ClientManager::new(); - + let result = manager.generate_client_credentials( "Generated App".to_string(), vec!["https://generated.com/callback".to_string()], ); - + assert!(result.is_ok()); let credentials = result.unwrap(); assert!(credentials.client_id.starts_with("client_")); assert!(!credentials.client_secret.is_empty()); - + // Verify the client was actually registered assert!(manager.authenticate_client(&credentials.client_id, &credentials.client_secret)); } } -*/
\ No newline at end of file +*/ |
