From 9b8a098bfcfdd73bfdfcff0cb397ef2694a90367 Mon Sep 17 00:00:00 2001 From: mo khan Date: Wed, 11 Jun 2025 17:13:57 -0600 Subject: refactor: extract sts bin --- Cargo.toml | 2 +- src/bin/sts.rs | 494 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 494 --------------------------------------------------------- 3 files changed, 495 insertions(+), 495 deletions(-) create mode 100644 src/bin/sts.rs delete mode 100644 src/main.rs diff --git a/Cargo.toml b/Cargo.toml index 039bb5f..5c2f1e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ edition = "2024" [[bin]] name = "sts" -path = "src/main.rs" +path = "src/bin/sts.rs" [[bin]] name = "migrate" diff --git a/src/bin/sts.rs b/src/bin/sts.rs new file mode 100644 index 0000000..4873a1d --- /dev/null +++ b/src/bin/sts.rs @@ -0,0 +1,494 @@ +use std::sync::{Arc, Mutex}; +use std::thread; +use std::time::Duration; +use sts::container::ServiceContainer; +use sts::http::Server; +use sts::{Config, Database}; + +fn main() { + let config = Config::from_env(); + + // Initialize database + let database = Database::new(&config.database_path).expect("Failed to initialize database"); + let database = Arc::new(Mutex::new(database)); + + // Initialize service container with dependency injection + let container = ServiceContainer::new(config.clone(), database.clone()) + .expect("Failed to create service container"); + let container = Arc::new(container); + + let server = Server::new_with_container(config.clone(), container.clone()) + .expect("Failed to create server"); + + // Start cleanup task in background + let cleanup_container = container.clone(); + let cleanup_config = config.clone(); + thread::spawn(move || { + loop { + thread::sleep(Duration::from_secs( + cleanup_config.cleanup_interval_hours as u64 * 3600, + )); + if let Err(e) = cleanup_container.cleanup_expired_data() { + eprintln!("Cleanup task failed: {}", e); + } + } + }); + + println!("Starting OAuth2 STS server..."); + println!("Configuration:"); + println!(" Bind Address: {}", config.bind_addr); + println!(" Issuer URL: {}", config.issuer_url); + println!(" Database: {}", config.database_path); + println!( + " Rate Limit: {} requests/minute", + config.rate_limit_requests_per_minute + ); + println!(" Audit Logging: {}", config.enable_audit_logging); + + server.start(); +} + +#[cfg(test)] +mod tests { + use base64::Engine; + use std::collections::HashMap; + + fn setup_test_environment() -> sts::Config { + let mut config = sts::Config::from_env(); + config.database_path = ":memory:".to_string(); // Use in-memory database for tests + config.bind_addr = "127.0.0.1:0".to_string(); + config.issuer_url = format!("http://{}", config.bind_addr); + + config + } + + #[test] + fn test_oauth_server_creation() { + let config = setup_test_environment(); + let server = sts::http::Server::new(config); + assert!(server.is_ok()); + } + + #[test] + fn test_authorization_code_generation() { + let config = setup_test_environment(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + let mut params = HashMap::new(); + params.insert("client_id".to_string(), "test_client".to_string()); + params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); + params.insert("response_type".to_string(), "code".to_string()); + params.insert("state".to_string(), "test_state".to_string()); + + let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); + assert!(result.is_ok()); + + let redirect_url = result.unwrap(); + assert!(redirect_url.contains("code=")); + assert!(redirect_url.contains("state=test_state")); + } + + #[test] + fn test_missing_client_id() { + let config = setup_test_environment(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + let mut params = HashMap::new(); + params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); + params.insert("response_type".to_string(), "code".to_string()); + + let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("invalid_request")); + } + + #[test] + fn test_unsupported_response_type() { + let config = setup_test_environment(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + let mut params = HashMap::new(); + params.insert("client_id".to_string(), "test_client".to_string()); + params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); + params.insert("response_type".to_string(), "token".to_string()); + + let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("unsupported_response_type")); + } + + #[test] + fn test_jwks_endpoint() { + let config = setup_test_environment(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + let jwks_json = oauth_server.get_jwks(); + assert!(!jwks_json.is_empty()); + + // Parse the JSON to verify structure + let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON"); + assert!(jwks["keys"].is_array()); + assert!(jwks["keys"].as_array().unwrap().len() > 0); + + let key = &jwks["keys"][0]; + assert_eq!(key["kty"], "RSA"); + assert_eq!(key["use"], "sig"); + assert_eq!(key["alg"], "RS256"); + assert!(key["kid"].is_string()); + assert!(key["n"].is_string()); + assert!(key["e"].is_string()); + } + + #[test] + fn test_full_oauth_flow_with_rsa_tokens() { + let config = setup_test_environment(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + // Step 1: Authorization request + let mut auth_params = HashMap::new(); + auth_params.insert("client_id".to_string(), "test_client".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); + auth_params.insert("response_type".to_string(), "code".to_string()); + auth_params.insert("state".to_string(), "test_state".to_string()); + + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); + + // Extract the authorization code from redirect URL + let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); + let auth_code = redirect_url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| value.to_string()) + .expect("No authorization code in redirect"); + + // Step 2: Token request + let mut token_params = HashMap::new(); + token_params.insert("grant_type".to_string(), "authorization_code".to_string()); + token_params.insert("code".to_string(), auth_code); + token_params.insert("client_id".to_string(), "test_client".to_string()); + token_params.insert("client_secret".to_string(), "test_secret".to_string()); + + let token_result = oauth_server + .handle_token(&token_params, None, Some("127.0.0.1".to_string())) + .expect("Token request failed"); + + // Parse token response + let token_response: serde_json::Value = + serde_json::from_str(&token_result).expect("Invalid token response JSON"); + + assert_eq!(token_response["token_type"], "Bearer"); + assert_eq!(token_response["expires_in"], 3600); + assert!(token_response["access_token"].is_string()); + + let access_token = token_response["access_token"].as_str().unwrap(); + + // Step 3: Verify the JWT token has RSA signature and key ID + let header = + jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); + assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256); + assert!(header.kid.is_some()); + assert!(!header.kid.as_ref().unwrap().is_empty()); + } + + #[test] + fn test_token_validation_with_jwks() { + let config = setup_test_environment(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + // Generate a token + let mut auth_params = HashMap::new(); + auth_params.insert("client_id".to_string(), "test_client".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); + auth_params.insert("response_type".to_string(), "code".to_string()); + + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); + let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); + let auth_code = redirect_url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| value.to_string()) + .expect("No authorization code in redirect"); + + let mut token_params = HashMap::new(); + token_params.insert("grant_type".to_string(), "authorization_code".to_string()); + token_params.insert("code".to_string(), auth_code); + token_params.insert("client_id".to_string(), "test_client".to_string()); + token_params.insert("client_secret".to_string(), "test_secret".to_string()); + + let token_result = oauth_server + .handle_token(&token_params, None, Some("127.0.0.1".to_string())) + .expect("Token request failed"); + let token_response: serde_json::Value = + serde_json::from_str(&token_result).expect("Invalid token response JSON"); + let access_token = token_response["access_token"].as_str().unwrap(); + + // Get the JWKS + let jwks_json = oauth_server.get_jwks(); + let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON"); + + // Decode the token header to get the key ID + let header = + jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); + let kid = header.kid.as_ref().expect("No key ID in token"); + + // Find the matching key in JWKS + let matching_key = jwks["keys"] + .as_array() + .unwrap() + .iter() + .find(|key| key["kid"] == *kid) + .expect("Key ID not found in JWKS"); + + assert_eq!(matching_key["kty"], "RSA"); + assert_eq!(matching_key["alg"], "RS256"); + } + + #[test] + fn test_token_contains_proper_claims() { + let config = setup_test_environment(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + // Generate a token through the full flow + let mut auth_params = HashMap::new(); + auth_params.insert("client_id".to_string(), "test_client".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); + auth_params.insert("response_type".to_string(), "code".to_string()); + auth_params.insert("scope".to_string(), "openid profile".to_string()); + + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); + let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); + let auth_code = redirect_url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| value.to_string()) + .expect("No authorization code in redirect"); + + let mut token_params = HashMap::new(); + token_params.insert("grant_type".to_string(), "authorization_code".to_string()); + token_params.insert("code".to_string(), auth_code); + token_params.insert("client_id".to_string(), "test_client".to_string()); + token_params.insert("client_secret".to_string(), "test_secret".to_string()); + + let token_result = oauth_server + .handle_token(&token_params, None, Some("127.0.0.1".to_string())) + .expect("Token request failed"); + let token_response: serde_json::Value = + serde_json::from_str(&token_result).expect("Invalid token response JSON"); + let access_token = token_response["access_token"].as_str().unwrap(); + + // Decode the token without verification to check claims + let _token_data = jsonwebtoken::decode::( + access_token, + &jsonwebtoken::DecodingKey::from_secret(b"dummy"), // We're not validating, just parsing + &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256), + ); + + // Since we can't validate with a dummy key, we'll just verify the structure + // by decoding the payload manually + let parts: Vec<&str> = access_token.split('.').collect(); + assert_eq!(parts.len(), 3); // header.payload.signature + + let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(parts[1]) + .expect("Failed to decode payload"); + let claims: serde_json::Value = + serde_json::from_slice(&payload).expect("Invalid claims JSON"); + + assert!(claims["sub"].is_string()); + assert!(claims["iss"].is_string()); + assert!(claims["aud"].is_string()); + assert!(claims["exp"].is_number()); + assert!(claims["iat"].is_number()); + assert_eq!(claims["aud"], "test_client"); + assert_eq!(claims["scope"], "openid profile"); + } + + #[test] + fn test_invalid_client_id_authorization() { + let config = setup_test_environment(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + let mut params = HashMap::new(); + params.insert("client_id".to_string(), "invalid_client".to_string()); + params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); + params.insert("response_type".to_string(), "code".to_string()); + + let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("invalid_client")); + } + + #[test] + fn test_invalid_redirect_uri() { + let config = setup_test_environment(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + let mut params = HashMap::new(); + params.insert("client_id".to_string(), "test_client".to_string()); + params.insert( + "redirect_uri".to_string(), + "https://evil.com/callback".to_string(), + ); + params.insert("response_type".to_string(), "code".to_string()); + + let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Invalid redirect_uri")); + } + + #[test] + fn test_invalid_scope() { + let config = setup_test_environment(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + let mut params = HashMap::new(); + params.insert("client_id".to_string(), "test_client".to_string()); + params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); + params.insert("response_type".to_string(), "code".to_string()); + params.insert("scope".to_string(), "invalid_scope".to_string()); + + let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("invalid_scope")); + } + + #[test] + fn test_invalid_client_secret() { + let config = setup_test_environment(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + // First get an authorization code + let mut auth_params = HashMap::new(); + auth_params.insert("client_id".to_string(), "test_client".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); + auth_params.insert("response_type".to_string(), "code".to_string()); + + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); + let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); + let auth_code = redirect_url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| value.to_string()) + .expect("No authorization code in redirect"); + + // Try to exchange with wrong client secret + let mut token_params = HashMap::new(); + token_params.insert("grant_type".to_string(), "authorization_code".to_string()); + token_params.insert("code".to_string(), auth_code); + token_params.insert("client_id".to_string(), "test_client".to_string()); + token_params.insert("client_secret".to_string(), "wrong_secret".to_string()); + + let result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("invalid_client")); + } + + #[test] + fn test_missing_client_secret() { + let config = setup_test_environment(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + // First get an authorization code + let mut auth_params = HashMap::new(); + auth_params.insert("client_id".to_string(), "test_client".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); + auth_params.insert("response_type".to_string(), "code".to_string()); + + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); + let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); + let auth_code = redirect_url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| value.to_string()) + .expect("No authorization code in redirect"); + + // Try to exchange without client secret + let mut token_params = HashMap::new(); + token_params.insert("grant_type".to_string(), "authorization_code".to_string()); + token_params.insert("code".to_string(), auth_code); + token_params.insert("client_id".to_string(), "test_client".to_string()); + // Missing client_secret + + let result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Missing client_secret")); + } + + #[test] + fn test_http_basic_auth() { + let config = setup_test_environment(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + // First get an authorization code + let mut auth_params = HashMap::new(); + auth_params.insert("client_id".to_string(), "test_client".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); + auth_params.insert("response_type".to_string(), "code".to_string()); + + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); + let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); + let auth_code = redirect_url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| value.to_string()) + .expect("No authorization code in redirect"); + + // Use HTTP Basic Auth instead of form parameters + let mut token_params = HashMap::new(); + token_params.insert("grant_type".to_string(), "authorization_code".to_string()); + token_params.insert("code".to_string(), auth_code); + // No client_id/client_secret in form + + // test_client:test_secret encoded in base64 + let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; + + let result = oauth_server.handle_token( + &token_params, + Some(auth_header), + Some("127.0.0.1".to_string()), + ); + assert!(result.is_ok()); + } +} diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index 4873a1d..0000000 --- a/src/main.rs +++ /dev/null @@ -1,494 +0,0 @@ -use std::sync::{Arc, Mutex}; -use std::thread; -use std::time::Duration; -use sts::container::ServiceContainer; -use sts::http::Server; -use sts::{Config, Database}; - -fn main() { - let config = Config::from_env(); - - // Initialize database - let database = Database::new(&config.database_path).expect("Failed to initialize database"); - let database = Arc::new(Mutex::new(database)); - - // Initialize service container with dependency injection - let container = ServiceContainer::new(config.clone(), database.clone()) - .expect("Failed to create service container"); - let container = Arc::new(container); - - let server = Server::new_with_container(config.clone(), container.clone()) - .expect("Failed to create server"); - - // Start cleanup task in background - let cleanup_container = container.clone(); - let cleanup_config = config.clone(); - thread::spawn(move || { - loop { - thread::sleep(Duration::from_secs( - cleanup_config.cleanup_interval_hours as u64 * 3600, - )); - if let Err(e) = cleanup_container.cleanup_expired_data() { - eprintln!("Cleanup task failed: {}", e); - } - } - }); - - println!("Starting OAuth2 STS server..."); - println!("Configuration:"); - println!(" Bind Address: {}", config.bind_addr); - println!(" Issuer URL: {}", config.issuer_url); - println!(" Database: {}", config.database_path); - println!( - " Rate Limit: {} requests/minute", - config.rate_limit_requests_per_minute - ); - println!(" Audit Logging: {}", config.enable_audit_logging); - - server.start(); -} - -#[cfg(test)] -mod tests { - use base64::Engine; - use std::collections::HashMap; - - fn setup_test_environment() -> sts::Config { - let mut config = sts::Config::from_env(); - config.database_path = ":memory:".to_string(); // Use in-memory database for tests - config.bind_addr = "127.0.0.1:0".to_string(); - config.issuer_url = format!("http://{}", config.bind_addr); - - config - } - - #[test] - fn test_oauth_server_creation() { - let config = setup_test_environment(); - let server = sts::http::Server::new(config); - assert!(server.is_ok()); - } - - #[test] - fn test_authorization_code_generation() { - let config = setup_test_environment(); - let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - let mut params = HashMap::new(); - params.insert("client_id".to_string(), "test_client".to_string()); - params.insert( - "redirect_uri".to_string(), - "http://localhost:3000/callback".to_string(), - ); - params.insert("response_type".to_string(), "code".to_string()); - params.insert("state".to_string(), "test_state".to_string()); - - let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); - assert!(result.is_ok()); - - let redirect_url = result.unwrap(); - assert!(redirect_url.contains("code=")); - assert!(redirect_url.contains("state=test_state")); - } - - #[test] - fn test_missing_client_id() { - let config = setup_test_environment(); - let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - let mut params = HashMap::new(); - params.insert( - "redirect_uri".to_string(), - "http://localhost:3000/callback".to_string(), - ); - params.insert("response_type".to_string(), "code".to_string()); - - let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("invalid_request")); - } - - #[test] - fn test_unsupported_response_type() { - let config = setup_test_environment(); - let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - let mut params = HashMap::new(); - params.insert("client_id".to_string(), "test_client".to_string()); - params.insert( - "redirect_uri".to_string(), - "http://localhost:3000/callback".to_string(), - ); - params.insert("response_type".to_string(), "token".to_string()); - - let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("unsupported_response_type")); - } - - #[test] - fn test_jwks_endpoint() { - let config = setup_test_environment(); - let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - - let jwks_json = oauth_server.get_jwks(); - assert!(!jwks_json.is_empty()); - - // Parse the JSON to verify structure - let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON"); - assert!(jwks["keys"].is_array()); - assert!(jwks["keys"].as_array().unwrap().len() > 0); - - let key = &jwks["keys"][0]; - assert_eq!(key["kty"], "RSA"); - assert_eq!(key["use"], "sig"); - assert_eq!(key["alg"], "RS256"); - assert!(key["kid"].is_string()); - assert!(key["n"].is_string()); - assert!(key["e"].is_string()); - } - - #[test] - fn test_full_oauth_flow_with_rsa_tokens() { - let config = setup_test_environment(); - let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - - // Step 1: Authorization request - let mut auth_params = HashMap::new(); - auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert( - "redirect_uri".to_string(), - "http://localhost:3000/callback".to_string(), - ); - auth_params.insert("response_type".to_string(), "code".to_string()); - auth_params.insert("state".to_string(), "test_state".to_string()); - - let auth_result = oauth_server - .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) - .expect("Authorization failed"); - - // Extract the authorization code from redirect URL - let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); - let auth_code = redirect_url - .query_pairs() - .find(|(key, _)| key == "code") - .map(|(_, value)| value.to_string()) - .expect("No authorization code in redirect"); - - // Step 2: Token request - let mut token_params = HashMap::new(); - token_params.insert("grant_type".to_string(), "authorization_code".to_string()); - token_params.insert("code".to_string(), auth_code); - token_params.insert("client_id".to_string(), "test_client".to_string()); - token_params.insert("client_secret".to_string(), "test_secret".to_string()); - - let token_result = oauth_server - .handle_token(&token_params, None, Some("127.0.0.1".to_string())) - .expect("Token request failed"); - - // Parse token response - let token_response: serde_json::Value = - serde_json::from_str(&token_result).expect("Invalid token response JSON"); - - assert_eq!(token_response["token_type"], "Bearer"); - assert_eq!(token_response["expires_in"], 3600); - assert!(token_response["access_token"].is_string()); - - let access_token = token_response["access_token"].as_str().unwrap(); - - // Step 3: Verify the JWT token has RSA signature and key ID - let header = - jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); - assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256); - assert!(header.kid.is_some()); - assert!(!header.kid.as_ref().unwrap().is_empty()); - } - - #[test] - fn test_token_validation_with_jwks() { - let config = setup_test_environment(); - let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - - // Generate a token - let mut auth_params = HashMap::new(); - auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert( - "redirect_uri".to_string(), - "http://localhost:3000/callback".to_string(), - ); - auth_params.insert("response_type".to_string(), "code".to_string()); - - let auth_result = oauth_server - .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) - .expect("Authorization failed"); - let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); - let auth_code = redirect_url - .query_pairs() - .find(|(key, _)| key == "code") - .map(|(_, value)| value.to_string()) - .expect("No authorization code in redirect"); - - let mut token_params = HashMap::new(); - token_params.insert("grant_type".to_string(), "authorization_code".to_string()); - token_params.insert("code".to_string(), auth_code); - token_params.insert("client_id".to_string(), "test_client".to_string()); - token_params.insert("client_secret".to_string(), "test_secret".to_string()); - - let token_result = oauth_server - .handle_token(&token_params, None, Some("127.0.0.1".to_string())) - .expect("Token request failed"); - let token_response: serde_json::Value = - serde_json::from_str(&token_result).expect("Invalid token response JSON"); - let access_token = token_response["access_token"].as_str().unwrap(); - - // Get the JWKS - let jwks_json = oauth_server.get_jwks(); - let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON"); - - // Decode the token header to get the key ID - let header = - jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); - let kid = header.kid.as_ref().expect("No key ID in token"); - - // Find the matching key in JWKS - let matching_key = jwks["keys"] - .as_array() - .unwrap() - .iter() - .find(|key| key["kid"] == *kid) - .expect("Key ID not found in JWKS"); - - assert_eq!(matching_key["kty"], "RSA"); - assert_eq!(matching_key["alg"], "RS256"); - } - - #[test] - fn test_token_contains_proper_claims() { - let config = setup_test_environment(); - let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - - // Generate a token through the full flow - let mut auth_params = HashMap::new(); - auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert( - "redirect_uri".to_string(), - "http://localhost:3000/callback".to_string(), - ); - auth_params.insert("response_type".to_string(), "code".to_string()); - auth_params.insert("scope".to_string(), "openid profile".to_string()); - - let auth_result = oauth_server - .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) - .expect("Authorization failed"); - let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); - let auth_code = redirect_url - .query_pairs() - .find(|(key, _)| key == "code") - .map(|(_, value)| value.to_string()) - .expect("No authorization code in redirect"); - - let mut token_params = HashMap::new(); - token_params.insert("grant_type".to_string(), "authorization_code".to_string()); - token_params.insert("code".to_string(), auth_code); - token_params.insert("client_id".to_string(), "test_client".to_string()); - token_params.insert("client_secret".to_string(), "test_secret".to_string()); - - let token_result = oauth_server - .handle_token(&token_params, None, Some("127.0.0.1".to_string())) - .expect("Token request failed"); - let token_response: serde_json::Value = - serde_json::from_str(&token_result).expect("Invalid token response JSON"); - let access_token = token_response["access_token"].as_str().unwrap(); - - // Decode the token without verification to check claims - let _token_data = jsonwebtoken::decode::( - access_token, - &jsonwebtoken::DecodingKey::from_secret(b"dummy"), // We're not validating, just parsing - &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256), - ); - - // Since we can't validate with a dummy key, we'll just verify the structure - // by decoding the payload manually - let parts: Vec<&str> = access_token.split('.').collect(); - assert_eq!(parts.len(), 3); // header.payload.signature - - let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD - .decode(parts[1]) - .expect("Failed to decode payload"); - let claims: serde_json::Value = - serde_json::from_slice(&payload).expect("Invalid claims JSON"); - - assert!(claims["sub"].is_string()); - assert!(claims["iss"].is_string()); - assert!(claims["aud"].is_string()); - assert!(claims["exp"].is_number()); - assert!(claims["iat"].is_number()); - assert_eq!(claims["aud"], "test_client"); - assert_eq!(claims["scope"], "openid profile"); - } - - #[test] - fn test_invalid_client_id_authorization() { - let config = setup_test_environment(); - let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - - let mut params = HashMap::new(); - params.insert("client_id".to_string(), "invalid_client".to_string()); - params.insert( - "redirect_uri".to_string(), - "http://localhost:3000/callback".to_string(), - ); - params.insert("response_type".to_string(), "code".to_string()); - - let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("invalid_client")); - } - - #[test] - fn test_invalid_redirect_uri() { - let config = setup_test_environment(); - let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - - let mut params = HashMap::new(); - params.insert("client_id".to_string(), "test_client".to_string()); - params.insert( - "redirect_uri".to_string(), - "https://evil.com/callback".to_string(), - ); - params.insert("response_type".to_string(), "code".to_string()); - - let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("Invalid redirect_uri")); - } - - #[test] - fn test_invalid_scope() { - let config = setup_test_environment(); - let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - - let mut params = HashMap::new(); - params.insert("client_id".to_string(), "test_client".to_string()); - params.insert( - "redirect_uri".to_string(), - "http://localhost:3000/callback".to_string(), - ); - params.insert("response_type".to_string(), "code".to_string()); - params.insert("scope".to_string(), "invalid_scope".to_string()); - - let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("invalid_scope")); - } - - #[test] - fn test_invalid_client_secret() { - let config = setup_test_environment(); - let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - - // First get an authorization code - let mut auth_params = HashMap::new(); - auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert( - "redirect_uri".to_string(), - "http://localhost:3000/callback".to_string(), - ); - auth_params.insert("response_type".to_string(), "code".to_string()); - - let auth_result = oauth_server - .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) - .expect("Authorization failed"); - let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); - let auth_code = redirect_url - .query_pairs() - .find(|(key, _)| key == "code") - .map(|(_, value)| value.to_string()) - .expect("No authorization code in redirect"); - - // Try to exchange with wrong client secret - let mut token_params = HashMap::new(); - token_params.insert("grant_type".to_string(), "authorization_code".to_string()); - token_params.insert("code".to_string(), auth_code); - token_params.insert("client_id".to_string(), "test_client".to_string()); - token_params.insert("client_secret".to_string(), "wrong_secret".to_string()); - - let result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("invalid_client")); - } - - #[test] - fn test_missing_client_secret() { - let config = setup_test_environment(); - let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - - // First get an authorization code - let mut auth_params = HashMap::new(); - auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert( - "redirect_uri".to_string(), - "http://localhost:3000/callback".to_string(), - ); - auth_params.insert("response_type".to_string(), "code".to_string()); - - let auth_result = oauth_server - .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) - .expect("Authorization failed"); - let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); - let auth_code = redirect_url - .query_pairs() - .find(|(key, _)| key == "code") - .map(|(_, value)| value.to_string()) - .expect("No authorization code in redirect"); - - // Try to exchange without client secret - let mut token_params = HashMap::new(); - token_params.insert("grant_type".to_string(), "authorization_code".to_string()); - token_params.insert("code".to_string(), auth_code); - token_params.insert("client_id".to_string(), "test_client".to_string()); - // Missing client_secret - - let result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("Missing client_secret")); - } - - #[test] - fn test_http_basic_auth() { - let config = setup_test_environment(); - let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - - // First get an authorization code - let mut auth_params = HashMap::new(); - auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert( - "redirect_uri".to_string(), - "http://localhost:3000/callback".to_string(), - ); - auth_params.insert("response_type".to_string(), "code".to_string()); - - let auth_result = oauth_server - .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) - .expect("Authorization failed"); - let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); - let auth_code = redirect_url - .query_pairs() - .find(|(key, _)| key == "code") - .map(|(_, value)| value.to_string()) - .expect("No authorization code in redirect"); - - // Use HTTP Basic Auth instead of form parameters - let mut token_params = HashMap::new(); - token_params.insert("grant_type".to_string(), "authorization_code".to_string()); - token_params.insert("code".to_string(), auth_code); - // No client_id/client_secret in form - - // test_client:test_secret encoded in base64 - let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; - - let result = oauth_server.handle_token( - &token_params, - Some(auth_header), - Some("127.0.0.1".to_string()), - ); - assert!(result.is_ok()); - } -} -- cgit v1.2.3