From 1b280d0b9be71daf9dffa0f4fa33559eda91946f Mon Sep 17 00:00:00 2001 From: mo khan Date: Mon, 9 Jun 2025 19:39:28 -0600 Subject: Extract a registered client type --- src/clients.rs | 318 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/http/mod.rs | 15 ++- src/lib.rs | 2 + src/main.rs | 149 +++++++++++++++++++++++- src/oauth/server.rs | 53 +++++++-- 5 files changed, 526 insertions(+), 11 deletions(-) create mode 100644 src/clients.rs (limited to 'src') diff --git a/src/clients.rs b/src/clients.rs new file mode 100644 index 0000000..38d7450 --- /dev/null +++ b/src/clients.rs @@ -0,0 +1,318 @@ +use base64::Engine; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::collections::HashMap; +use uuid::Uuid; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RegisteredClient { + pub client_id: String, + pub client_secret_hash: String, + pub redirect_uris: Vec, + pub client_name: String, + pub scopes: Vec, + pub grant_types: Vec, + pub response_types: Vec, + pub created_at: u64, +} + +#[derive(Debug, Clone)] +pub struct ClientCredentials { + pub client_id: String, + pub client_secret: String, +} + +pub struct ClientManager { + clients: HashMap, +} + +impl ClientManager { + pub fn new() -> Self { + let mut manager = Self { + clients: HashMap::new(), + }; + + // Register a default test client for development + manager.register_client( + "test_client".to_string(), + "test_secret".to_string(), + vec!["http://localhost:3000/callback".to_string()], + "Test Client".to_string(), + vec!["openid".to_string(), "profile".to_string(), "email".to_string()], + ).ok(); + + manager + } + + pub fn register_client( + &mut self, + client_id: String, + client_secret: String, + redirect_uris: Vec, + client_name: String, + scopes: Vec, + ) -> Result { + // Check if client_id already exists + if self.clients.contains_key(&client_id) { + return Err("Client ID already exists".to_string()); + } + + // Validate redirect URIs + for uri in &redirect_uris { + if !Self::is_valid_redirect_uri(uri) { + return Err(format!("Invalid redirect URI: {}", uri)); + } + } + + // Hash the client secret + let client_secret_hash = Self::hash_secret(&client_secret); + + let client = RegisteredClient { + client_id: client_id.clone(), + client_secret_hash, + redirect_uris, + client_name, + scopes, + grant_types: vec!["authorization_code".to_string()], + response_types: vec!["code".to_string()], + created_at: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + }; + + self.clients.insert(client_id, client.clone()); + Ok(client) + } + + pub fn get_client(&self, client_id: &str) -> Option<&RegisteredClient> { + self.clients.get(client_id) + } + + pub fn authenticate_client(&self, client_id: &str, client_secret: &str) -> bool { + if let Some(client) = self.get_client(client_id) { + let provided_hash = Self::hash_secret(client_secret); + // Use constant-time comparison to prevent timing attacks + self.constant_time_eq(&client.client_secret_hash, &provided_hash) + } else { + // Still perform hashing even for non-existent clients to prevent timing attacks + Self::hash_secret(client_secret); + false + } + } + + pub fn is_redirect_uri_valid(&self, client_id: &str, redirect_uri: &str) -> bool { + if let Some(client) = self.get_client(client_id) { + client.redirect_uris.contains(&redirect_uri.to_string()) + } else { + false + } + } + + pub fn is_scope_valid(&self, client_id: &str, requested_scopes: &Option) -> bool { + if let Some(client) = self.get_client(client_id) { + if let Some(scopes_str) = requested_scopes { + let requested: Vec<&str> = scopes_str.split_whitespace().collect(); + requested.iter().all(|scope| client.scopes.contains(&scope.to_string())) + } else { + true // No scopes requested is valid + } + } else { + false + } + } + + fn hash_secret(secret: &str) -> String { + let mut hasher = Sha256::new(); + hasher.update(secret.as_bytes()); + format!("{:x}", hasher.finalize()) + } + + fn is_valid_redirect_uri(uri: &str) -> bool { + // Basic validation - in production, this should be more comprehensive + uri.starts_with("http://") || uri.starts_with("https://") + } + + // Constant-time string comparison to prevent timing attacks + fn constant_time_eq(&self, a: &str, b: &str) -> bool { + if a.len() != b.len() { + return false; + } + + let mut result = 0u8; + for (byte_a, byte_b) in a.bytes().zip(b.bytes()) { + result |= byte_a ^ byte_b; + } + result == 0 + } + + pub fn generate_client_credentials(&mut self, client_name: String, redirect_uris: Vec) -> Result { + let client_id = format!("client_{}", Uuid::new_v4().to_string().replace("-", "")); + let client_secret = Uuid::new_v4().to_string(); + + self.register_client( + client_id.clone(), + client_secret.clone(), + redirect_uris, + client_name, + vec!["openid".to_string(), "profile".to_string(), "email".to_string()], + )?; + + Ok(ClientCredentials { + client_id, + client_secret, + }) + } + + pub fn list_clients(&self) -> Vec<&RegisteredClient> { + self.clients.values().collect() + } +} + +// HTTP Basic Auth parsing helper +pub fn parse_basic_auth(auth_header: &str) -> Option<(String, String)> { + if !auth_header.starts_with("Basic ") { + return None; + } + + let encoded = &auth_header[6..]; + let decoded = base64::engine::general_purpose::STANDARD + .decode(encoded) + .ok()?; + + let credentials = String::from_utf8(decoded).ok()?; + let mut parts = credentials.splitn(2, ':'); + + let username = parts.next()?.to_string(); + let password = parts.next()?.to_string(); + + Some((username, password)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_client_registration() { + let mut manager = ClientManager::new(); + + let result = manager.register_client( + "new_client".to_string(), + "secret123".to_string(), + vec!["https://example.com/callback".to_string()], + "Test App".to_string(), + vec!["openid".to_string()], + ); + + assert!(result.is_ok()); + let client = result.unwrap(); + assert_eq!(client.client_id, "new_client"); + assert_eq!(client.client_name, "Test App"); + } + + #[test] + fn test_duplicate_client_id() { + let mut manager = ClientManager::new(); + + manager.register_client( + "duplicate".to_string(), + "secret1".to_string(), + vec!["https://example.com/callback".to_string()], + "App 1".to_string(), + vec!["openid".to_string()], + ).unwrap(); + + let result = manager.register_client( + "duplicate".to_string(), + "secret2".to_string(), + vec!["https://example.com/callback".to_string()], + "App 2".to_string(), + vec!["openid".to_string()], + ); + + assert!(result.is_err()); + } + + #[test] + fn test_client_authentication() { + let mut manager = ClientManager::new(); + + manager.register_client( + "auth_client".to_string(), + "correct_secret".to_string(), + vec!["https://example.com/callback".to_string()], + "Auth Test".to_string(), + vec!["openid".to_string()], + ).unwrap(); + + assert!(manager.authenticate_client("auth_client", "correct_secret")); + assert!(!manager.authenticate_client("auth_client", "wrong_secret")); + assert!(!manager.authenticate_client("nonexistent", "any_secret")); + } + + #[test] + fn test_redirect_uri_validation() { + let mut manager = ClientManager::new(); + + manager.register_client( + "uri_client".to_string(), + "secret".to_string(), + vec!["https://app.com/callback".to_string(), "http://localhost:3000/callback".to_string()], + "URI Test".to_string(), + vec!["openid".to_string()], + ).unwrap(); + + assert!(manager.is_redirect_uri_valid("uri_client", "https://app.com/callback")); + assert!(manager.is_redirect_uri_valid("uri_client", "http://localhost:3000/callback")); + assert!(!manager.is_redirect_uri_valid("uri_client", "https://evil.com/callback")); + assert!(!manager.is_redirect_uri_valid("nonexistent", "https://app.com/callback")); + } + + #[test] + fn test_scope_validation() { + let mut manager = ClientManager::new(); + + manager.register_client( + "scope_client".to_string(), + "secret".to_string(), + vec!["https://example.com/callback".to_string()], + "Scope Test".to_string(), + vec!["openid".to_string(), "profile".to_string()], + ).unwrap(); + + assert!(manager.is_scope_valid("scope_client", &Some("openid".to_string()))); + assert!(manager.is_scope_valid("scope_client", &Some("openid profile".to_string()))); + assert!(!manager.is_scope_valid("scope_client", &Some("openid profile email".to_string()))); + assert!(manager.is_scope_valid("scope_client", &None)); + } + + #[test] + fn test_basic_auth_parsing() { + let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; // test_client:test_secret + let result = parse_basic_auth(auth_header); + + assert!(result.is_some()); + let (username, password) = result.unwrap(); + assert_eq!(username, "test_client"); + assert_eq!(password, "test_secret"); + } + + #[test] + fn test_generate_client_credentials() { + let mut manager = ClientManager::new(); + + let result = manager.generate_client_credentials( + "Generated App".to_string(), + vec!["https://generated.com/callback".to_string()], + ); + + assert!(result.is_ok()); + let credentials = result.unwrap(); + assert!(credentials.client_id.starts_with("client_")); + assert!(!credentials.client_secret.is_empty()); + + // Verify the client was actually registered + assert!(manager.authenticate_client(&credentials.client_id, &credentials.client_secret)); + } +} \ No newline at end of file diff --git a/src/http/mod.rs b/src/http/mod.rs index 4523d3b..7b1b983 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -166,8 +166,11 @@ impl Server { fn handle_token(&self, stream: &mut TcpStream, request: &str) { let body = self.extract_body(request); let form_params = self.parse_form_data(&body); + + // Extract Authorization header from request + let auth_header = self.extract_auth_header(request); - match self.oauth_server.handle_token(&form_params) { + match self.oauth_server.handle_token(&form_params, auth_header.as_deref()) { Ok(token_response) => { self.send_json_response(stream, 200, "OK", &token_response); } @@ -200,4 +203,14 @@ impl Server { }) .collect() } + + fn extract_auth_header(&self, request: &str) -> Option { + let lines: Vec<&str> = request.lines().collect(); + for line in lines.iter().skip(1) { // Skip the request line + if line.to_lowercase().starts_with("authorization:") { + return Some(line[14..].trim().to_string()); // Skip "Authorization: " + } + } + None + } } \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 1bff27d..4ed4b7d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,10 @@ +pub mod clients; pub mod config; pub mod http; pub mod keys; pub mod oauth; +pub use clients::ClientManager; pub use config::Config; pub use http::Server; pub use oauth::OAuthServer; diff --git a/src/main.rs b/src/main.rs index 9b6b0ea..6b31f92 100644 --- a/src/main.rs +++ b/src/main.rs @@ -124,8 +124,9 @@ mod tests { token_params.insert("grant_type".to_string(), "authorization_code".to_string()); token_params.insert("code".to_string(), auth_code); token_params.insert("client_id".to_string(), "test_client".to_string()); + token_params.insert("client_secret".to_string(), "test_secret".to_string()); - let token_result = oauth_server.handle_token(&token_params).expect("Token request failed"); + let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed"); // Parse token response let token_response: serde_json::Value = serde_json::from_str(&token_result) @@ -167,8 +168,9 @@ mod tests { token_params.insert("grant_type".to_string(), "authorization_code".to_string()); token_params.insert("code".to_string(), auth_code); token_params.insert("client_id".to_string(), "test_client".to_string()); + token_params.insert("client_secret".to_string(), "test_secret".to_string()); - let token_result = oauth_server.handle_token(&token_params).expect("Token request failed"); + let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed"); let token_response: serde_json::Value = serde_json::from_str(&token_result) .expect("Invalid token response JSON"); let access_token = token_response["access_token"].as_str().unwrap(); @@ -217,8 +219,9 @@ mod tests { token_params.insert("grant_type".to_string(), "authorization_code".to_string()); token_params.insert("code".to_string(), auth_code); token_params.insert("client_id".to_string(), "test_client".to_string()); + token_params.insert("client_secret".to_string(), "test_secret".to_string()); - let token_result = oauth_server.handle_token(&token_params).expect("Token request failed"); + let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed"); let token_response: serde_json::Value = serde_json::from_str(&token_result) .expect("Invalid token response JSON"); let access_token = token_response["access_token"].as_str().unwrap(); @@ -248,4 +251,144 @@ mod tests { assert_eq!(claims["aud"], "test_client"); assert_eq!(claims["scope"], "openid profile"); } + + #[test] + fn test_invalid_client_id_authorization() { + let config = sts::Config::from_env(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + let mut params = HashMap::new(); + params.insert("client_id".to_string(), "invalid_client".to_string()); + params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + params.insert("response_type".to_string(), "code".to_string()); + + let result = oauth_server.handle_authorize(¶ms); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("invalid_client")); + } + + #[test] + fn test_invalid_redirect_uri() { + let config = sts::Config::from_env(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + let mut params = HashMap::new(); + params.insert("client_id".to_string(), "test_client".to_string()); + params.insert("redirect_uri".to_string(), "https://evil.com/callback".to_string()); + params.insert("response_type".to_string(), "code".to_string()); + + let result = oauth_server.handle_authorize(¶ms); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Invalid redirect_uri")); + } + + #[test] + fn test_invalid_scope() { + let config = sts::Config::from_env(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + let mut params = HashMap::new(); + params.insert("client_id".to_string(), "test_client".to_string()); + params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + params.insert("response_type".to_string(), "code".to_string()); + params.insert("scope".to_string(), "invalid_scope".to_string()); + + let result = oauth_server.handle_authorize(¶ms); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("invalid_scope")); + } + + #[test] + fn test_invalid_client_secret() { + let config = sts::Config::from_env(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + // First get an authorization code + let mut auth_params = HashMap::new(); + auth_params.insert("client_id".to_string(), "test_client".to_string()); + auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert("response_type".to_string(), "code".to_string()); + + let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed"); + let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); + let auth_code = redirect_url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| value.to_string()) + .expect("No authorization code in redirect"); + + // Try to exchange with wrong client secret + let mut token_params = HashMap::new(); + token_params.insert("grant_type".to_string(), "authorization_code".to_string()); + token_params.insert("code".to_string(), auth_code); + token_params.insert("client_id".to_string(), "test_client".to_string()); + token_params.insert("client_secret".to_string(), "wrong_secret".to_string()); + + let result = oauth_server.handle_token(&token_params, None); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("invalid_client")); + } + + #[test] + fn test_missing_client_secret() { + let config = sts::Config::from_env(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + // First get an authorization code + let mut auth_params = HashMap::new(); + auth_params.insert("client_id".to_string(), "test_client".to_string()); + auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert("response_type".to_string(), "code".to_string()); + + let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed"); + let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); + let auth_code = redirect_url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| value.to_string()) + .expect("No authorization code in redirect"); + + // Try to exchange without client secret + let mut token_params = HashMap::new(); + token_params.insert("grant_type".to_string(), "authorization_code".to_string()); + token_params.insert("code".to_string(), auth_code); + token_params.insert("client_id".to_string(), "test_client".to_string()); + // Missing client_secret + + let result = oauth_server.handle_token(&token_params, None); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Missing client_secret")); + } + + #[test] + fn test_http_basic_auth() { + let config = sts::Config::from_env(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + // First get an authorization code + let mut auth_params = HashMap::new(); + auth_params.insert("client_id".to_string(), "test_client".to_string()); + auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert("response_type".to_string(), "code".to_string()); + + let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed"); + let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); + let auth_code = redirect_url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| value.to_string()) + .expect("No authorization code in redirect"); + + // Use HTTP Basic Auth instead of form parameters + let mut token_params = HashMap::new(); + token_params.insert("grant_type".to_string(), "authorization_code".to_string()); + token_params.insert("code".to_string(), auth_code); + // No client_id/client_secret in form + + // test_client:test_secret encoded in base64 + let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; + + let result = oauth_server.handle_token(&token_params, Some(auth_header)); + assert!(result.is_ok()); + } } diff --git a/src/oauth/server.rs b/src/oauth/server.rs index 888b0c2..3ee0bb3 100644 --- a/src/oauth/server.rs +++ b/src/oauth/server.rs @@ -1,3 +1,4 @@ +use crate::clients::{ClientManager, parse_basic_auth}; use crate::config::Config; use crate::keys::KeyManager; use crate::oauth::types::{AuthCode, Claims, ErrorResponse, TokenResponse}; @@ -11,15 +12,18 @@ pub struct OAuthServer { config: Config, key_manager: std::sync::Mutex, auth_codes: std::sync::Mutex>, + client_manager: std::sync::Mutex, } impl OAuthServer { pub fn new(config: &Config) -> Result> { let key_manager = KeyManager::new()?; + let client_manager = ClientManager::new(); Ok(Self { key_manager: std::sync::Mutex::new(key_manager), auth_codes: std::sync::Mutex::new(HashMap::new()), + client_manager: std::sync::Mutex::new(client_manager), config: config.clone(), }) } @@ -45,6 +49,22 @@ impl OAuthServer { .get("response_type") .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?; + // Validate client exists + let client_manager = self.client_manager.lock().unwrap(); + let _client = client_manager.get_client(client_id) + .ok_or_else(|| self.error_response("invalid_client", "Invalid client_id"))?; + + // Validate redirect URI is registered for this client + if !client_manager.is_redirect_uri_valid(client_id, redirect_uri) { + return Err(self.error_response("invalid_request", "Invalid redirect_uri")); + } + + // Validate requested scopes + let scope = params.get("scope").cloned(); + if !client_manager.is_scope_valid(client_id, &scope) { + return Err(self.error_response("invalid_scope", "Invalid scope")); + } + if response_type != "code" { return Err(self.error_response( "unsupported_response_type", @@ -62,7 +82,7 @@ impl OAuthServer { let auth_code = AuthCode { client_id: client_id.clone(), redirect_uri: redirect_uri.clone(), - scope: params.get("scope").cloned(), + scope: scope, expires_at, user_id: "test_user".to_string(), }; @@ -84,7 +104,7 @@ impl OAuthServer { Ok(redirect_url.to_string()) } - pub fn handle_token(&self, params: &HashMap) -> Result { + pub fn handle_token(&self, params: &HashMap, auth_header: Option<&str>) -> Result { let grant_type = params .get("grant_type") .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; @@ -100,9 +120,28 @@ impl OAuthServer { .get("code") .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?; - let client_id = params - .get("client_id") - .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + // Client authentication - RFC 6749 Section 3.2.1 + // Clients can authenticate via HTTP Basic Auth or form parameters + let (client_id, client_secret) = if let Some(auth_header) = auth_header { + // HTTP Basic Authentication (preferred method) + parse_basic_auth(auth_header) + .ok_or_else(|| self.error_response("invalid_client", "Invalid Authorization header"))? + } else { + // Form-based authentication (fallback) + let client_id = params + .get("client_id") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + let client_secret = params + .get("client_secret") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_secret"))?; + (client_id.clone(), client_secret.clone()) + }; + + // Authenticate the client + let client_manager = self.client_manager.lock().unwrap(); + if !client_manager.authenticate_client(&client_id, &client_secret) { + return Err(self.error_response("invalid_client", "Client authentication failed")); + } let auth_code = { let mut codes = self.auth_codes.lock().unwrap(); @@ -111,7 +150,7 @@ impl OAuthServer { })? }; - if auth_code.client_id != *client_id { + if auth_code.client_id != client_id { return Err(self.error_response("invalid_grant", "Client ID mismatch")); } @@ -125,7 +164,7 @@ impl OAuthServer { } let access_token = - self.generate_access_token(&auth_code.user_id, client_id, &auth_code.scope)?; + self.generate_access_token(&auth_code.user_id, &client_id, &auth_code.scope)?; let token_response = TokenResponse { access_token, -- cgit v1.2.3