diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-09 19:39:28 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-09 19:39:28 -0600 |
| commit | 1b280d0b9be71daf9dffa0f4fa33559eda91946f (patch) | |
| tree | cc301717deabe901ce82ce469fef6be84c1b2bed /src/clients.rs | |
| parent | 50731865b1a22ab250e988aed2ea2bb8a18f9338 (diff) | |
Extract a registered client type
Diffstat (limited to 'src/clients.rs')
| -rw-r--r-- | src/clients.rs | 318 |
1 files changed, 318 insertions, 0 deletions
diff --git a/src/clients.rs b/src/clients.rs new file mode 100644 index 0000000..38d7450 --- /dev/null +++ b/src/clients.rs @@ -0,0 +1,318 @@ +use base64::Engine; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use uuid::Uuid; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RegisteredClient { + pub client_id: String, + pub client_secret_hash: String, + pub redirect_uris: Vec<String>, + pub client_name: String, + pub scopes: Vec<String>, + pub grant_types: Vec<String>, + pub response_types: Vec<String>, + pub created_at: u64, +} + +#[derive(Debug, Clone)] +pub struct ClientCredentials { + pub client_id: String, + pub client_secret: String, +} + +pub struct ClientManager { + clients: HashMap<String, RegisteredClient>, +} + +impl ClientManager { + pub fn new() -> Self { + let mut manager = Self { + clients: HashMap::new(), + }; + + // Register a default test client for development + manager.register_client( + "test_client".to_string(), + "test_secret".to_string(), + vec!["http://localhost:3000/callback".to_string()], + "Test Client".to_string(), + vec!["openid".to_string(), "profile".to_string(), "email".to_string()], + ).ok(); + + manager + } + + pub fn register_client( + &mut self, + client_id: String, + client_secret: String, + redirect_uris: Vec<String>, + client_name: String, + scopes: Vec<String>, + ) -> Result<RegisteredClient, String> { + // Check if client_id already exists + if self.clients.contains_key(&client_id) { + return Err("Client ID already exists".to_string()); + } + + // Validate redirect URIs + for uri in &redirect_uris { + if !Self::is_valid_redirect_uri(uri) { + return Err(format!("Invalid redirect URI: {}", uri)); + } + } + + // Hash the client secret + let client_secret_hash = Self::hash_secret(&client_secret); + + let client = RegisteredClient { + client_id: client_id.clone(), + client_secret_hash, + redirect_uris, + client_name, + scopes, + grant_types: vec!["authorization_code".to_string()], + response_types: vec!["code".to_string()], + created_at: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + }; + + self.clients.insert(client_id, client.clone()); + Ok(client) + } + + pub fn get_client(&self, client_id: &str) -> Option<&RegisteredClient> { + self.clients.get(client_id) + } + + pub fn authenticate_client(&self, client_id: &str, client_secret: &str) -> bool { + if let Some(client) = self.get_client(client_id) { + let provided_hash = Self::hash_secret(client_secret); + // Use constant-time comparison to prevent timing attacks + self.constant_time_eq(&client.client_secret_hash, &provided_hash) + } else { + // Still perform hashing even for non-existent clients to prevent timing attacks + Self::hash_secret(client_secret); + false + } + } + + pub fn is_redirect_uri_valid(&self, client_id: &str, redirect_uri: &str) -> bool { + if let Some(client) = self.get_client(client_id) { + client.redirect_uris.contains(&redirect_uri.to_string()) + } else { + false + } + } + + pub fn is_scope_valid(&self, client_id: &str, requested_scopes: &Option<String>) -> bool { + if let Some(client) = self.get_client(client_id) { + if let Some(scopes_str) = requested_scopes { + let requested: Vec<&str> = scopes_str.split_whitespace().collect(); + requested.iter().all(|scope| client.scopes.contains(&scope.to_string())) + } else { + true // No scopes requested is valid + } + } else { + false + } + } + + fn hash_secret(secret: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(secret.as_bytes()); + format!("{:x}", hasher.finalize()) + } + + fn is_valid_redirect_uri(uri: &str) -> bool { + // Basic validation - in production, this should be more comprehensive + uri.starts_with("http://") || uri.starts_with("https://") + } + + // Constant-time string comparison to prevent timing attacks + fn constant_time_eq(&self, a: &str, b: &str) -> bool { + if a.len() != b.len() { + return false; + } + + let mut result = 0u8; + for (byte_a, byte_b) in a.bytes().zip(b.bytes()) { + result |= byte_a ^ byte_b; + } + result == 0 + } + + pub fn generate_client_credentials(&mut self, client_name: String, redirect_uris: Vec<String>) -> Result<ClientCredentials, String> { + let client_id = format!("client_{}", Uuid::new_v4().to_string().replace("-", "")); + let client_secret = Uuid::new_v4().to_string(); + + self.register_client( + client_id.clone(), + client_secret.clone(), + redirect_uris, + client_name, + vec!["openid".to_string(), "profile".to_string(), "email".to_string()], + )?; + + Ok(ClientCredentials { + client_id, + client_secret, + }) + } + + pub fn list_clients(&self) -> Vec<&RegisteredClient> { + self.clients.values().collect() + } +} + +// HTTP Basic Auth parsing helper +pub fn parse_basic_auth(auth_header: &str) -> Option<(String, String)> { + if !auth_header.starts_with("Basic ") { + return None; + } + + let encoded = &auth_header[6..]; + let decoded = base64::engine::general_purpose::STANDARD + .decode(encoded) + .ok()?; + + let credentials = String::from_utf8(decoded).ok()?; + let mut parts = credentials.splitn(2, ':'); + + let username = parts.next()?.to_string(); + let password = parts.next()?.to_string(); + + Some((username, password)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_client_registration() { + let mut manager = ClientManager::new(); + + let result = manager.register_client( + "new_client".to_string(), + "secret123".to_string(), + vec!["https://example.com/callback".to_string()], + "Test App".to_string(), + vec!["openid".to_string()], + ); + + assert!(result.is_ok()); + let client = result.unwrap(); + assert_eq!(client.client_id, "new_client"); + assert_eq!(client.client_name, "Test App"); + } + + #[test] + fn test_duplicate_client_id() { + let mut manager = ClientManager::new(); + + manager.register_client( + "duplicate".to_string(), + "secret1".to_string(), + vec!["https://example.com/callback".to_string()], + "App 1".to_string(), + vec!["openid".to_string()], + ).unwrap(); + + let result = manager.register_client( + "duplicate".to_string(), + "secret2".to_string(), + vec!["https://example.com/callback".to_string()], + "App 2".to_string(), + vec!["openid".to_string()], + ); + + assert!(result.is_err()); + } + + #[test] + fn test_client_authentication() { + let mut manager = ClientManager::new(); + + manager.register_client( + "auth_client".to_string(), + "correct_secret".to_string(), + vec!["https://example.com/callback".to_string()], + "Auth Test".to_string(), + vec!["openid".to_string()], + ).unwrap(); + + assert!(manager.authenticate_client("auth_client", "correct_secret")); + assert!(!manager.authenticate_client("auth_client", "wrong_secret")); + assert!(!manager.authenticate_client("nonexistent", "any_secret")); + } + + #[test] + fn test_redirect_uri_validation() { + let mut manager = ClientManager::new(); + + manager.register_client( + "uri_client".to_string(), + "secret".to_string(), + vec!["https://app.com/callback".to_string(), "http://localhost:3000/callback".to_string()], + "URI Test".to_string(), + vec!["openid".to_string()], + ).unwrap(); + + assert!(manager.is_redirect_uri_valid("uri_client", "https://app.com/callback")); + assert!(manager.is_redirect_uri_valid("uri_client", "http://localhost:3000/callback")); + assert!(!manager.is_redirect_uri_valid("uri_client", "https://evil.com/callback")); + assert!(!manager.is_redirect_uri_valid("nonexistent", "https://app.com/callback")); + } + + #[test] + fn test_scope_validation() { + let mut manager = ClientManager::new(); + + manager.register_client( + "scope_client".to_string(), + "secret".to_string(), + vec!["https://example.com/callback".to_string()], + "Scope Test".to_string(), + vec!["openid".to_string(), "profile".to_string()], + ).unwrap(); + + assert!(manager.is_scope_valid("scope_client", &Some("openid".to_string()))); + assert!(manager.is_scope_valid("scope_client", &Some("openid profile".to_string()))); + assert!(!manager.is_scope_valid("scope_client", &Some("openid profile email".to_string()))); + assert!(manager.is_scope_valid("scope_client", &None)); + } + + #[test] + fn test_basic_auth_parsing() { + let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; // test_client:test_secret + let result = parse_basic_auth(auth_header); + + assert!(result.is_some()); + let (username, password) = result.unwrap(); + assert_eq!(username, "test_client"); + assert_eq!(password, "test_secret"); + } + + #[test] + fn test_generate_client_credentials() { + let mut manager = ClientManager::new(); + + let result = manager.generate_client_credentials( + "Generated App".to_string(), + vec!["https://generated.com/callback".to_string()], + ); + + assert!(result.is_ok()); + let credentials = result.unwrap(); + assert!(credentials.client_id.starts_with("client_")); + assert!(!credentials.client_secret.is_empty()); + + // Verify the client was actually registered + assert!(manager.authenticate_client(&credentials.client_id, &credentials.client_secret)); + } +}
\ No newline at end of file |
