diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-09 17:27:57 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-09 17:27:57 -0600 |
| commit | 50731865b1a22ab250e988aed2ea2bb8a18f9338 (patch) | |
| tree | 2155fb783438b9c2d42bbe7cf2821f46d75d3a06 | |
| parent | 72c2297eda4c18f75e7d8587773b36f3ac98b309 (diff) | |
test: add missing tests
| -rw-r--r-- | src/keys.rs | 115 | ||||
| -rw-r--r-- | src/main.rs | 176 |
2 files changed, 290 insertions, 1 deletions
diff --git a/src/keys.rs b/src/keys.rs index 6c25681..10c9697 100644 --- a/src/keys.rs +++ b/src/keys.rs @@ -21,6 +21,7 @@ pub struct KeyPair { #[derive(Debug, Serialize)] pub struct JwkKey { pub kty: String, + #[serde(rename = "use")] pub use_: String, pub kid: String, pub alg: String, @@ -152,3 +153,117 @@ impl KeyManager { }); } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_key_manager_creation() { + let manager = KeyManager::new().expect("Failed to create key manager"); + assert!(manager.get_current_key().is_some()); + assert_eq!(manager.keys.len(), 1); + } + + #[test] + fn test_key_generation() { + let mut manager = KeyManager::new().expect("Failed to create key manager"); + let initial_key_count = manager.keys.len(); + + let new_kid = manager.generate_new_key().expect("Failed to generate new key"); + assert_eq!(manager.keys.len(), initial_key_count + 1); + assert_eq!(manager.current_key_id, Some(new_kid.clone())); + assert!(manager.get_key(&new_kid).is_some()); + } + + #[test] + fn test_jwks_generation() { + let manager = KeyManager::new().expect("Failed to create key manager"); + let jwks = manager.get_jwks().expect("Failed to get JWKS"); + + assert_eq!(jwks.keys.len(), 1); + let key = &jwks.keys[0]; + assert_eq!(key.kty, "RSA"); + assert_eq!(key.use_, "sig"); + assert_eq!(key.alg, "RS256"); + assert!(!key.n.is_empty()); + assert!(!key.e.is_empty()); + assert!(!key.kid.is_empty()); + } + + #[test] + fn test_key_rotation() { + let mut manager = KeyManager::new().expect("Failed to create key manager"); + let original_kid = manager.current_key_id.clone().unwrap(); + + manager.rotate_keys().expect("Failed to rotate keys"); + + let new_kid = manager.current_key_id.clone().unwrap(); + assert_ne!(original_kid, new_kid); + assert_eq!(manager.keys.len(), 2); // Should have both old and new keys + assert!(manager.get_key(&original_kid).is_some()); + assert!(manager.get_key(&new_kid).is_some()); + } + + #[test] + fn test_should_rotate_new_key() { + let manager = KeyManager::new().expect("Failed to create key manager"); + // New key should not need rotation + assert!(!manager.should_rotate()); + } + + #[test] + fn test_should_rotate_old_key() { + let mut manager = KeyManager::new().expect("Failed to create key manager"); + + // Manually modify the current key's creation time to be old + if let Some(current_kid) = manager.current_key_id.clone() { + if let Some(key_pair) = manager.keys.get_mut(¤t_kid) { + let old_time = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() - 86401; // 1 day + 1 second ago + + // We need to recreate the key pair with the old timestamp + let mut old_key_pair = key_pair.clone(); + old_key_pair.created_at = old_time; + manager.keys.insert(current_kid, old_key_pair); + } + } + + // Should need rotation since key is older than rotation interval + assert!(manager.should_rotate()); + } + + #[test] + fn test_cleanup_old_keys() { + let mut manager = KeyManager::new().expect("Failed to create key manager"); + let original_kid = manager.current_key_id.clone().unwrap(); + + // Generate a new key (so we have 2 keys) + manager.rotate_keys().expect("Failed to rotate keys"); + assert_eq!(manager.keys.len(), 2); + + // Cleanup with max_age 0 should remove old keys but keep current + manager.cleanup_old_keys(0); + assert_eq!(manager.keys.len(), 1); + assert!(manager.get_key(&original_kid).is_none()); + assert!(manager.get_current_key().is_some()); + } + + #[test] + fn test_multiple_key_jwks() { + let mut manager = KeyManager::new().expect("Failed to create key manager"); + manager.rotate_keys().expect("Failed to rotate keys"); + manager.rotate_keys().expect("Failed to rotate keys"); + + let jwks = manager.get_jwks().expect("Failed to get JWKS"); + assert_eq!(jwks.keys.len(), 3); // Should have 3 keys + + // All keys should have unique key IDs + let mut kids: Vec<String> = jwks.keys.iter().map(|k| k.kid.clone()).collect(); + kids.sort(); + kids.dedup(); + assert_eq!(kids.len(), 3); // All should be unique + } +} diff --git a/src/main.rs b/src/main.rs index 25d56ea..9b6b0ea 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,8 +12,8 @@ fn main() { #[cfg(test)] mod tests { - use super::*; use std::collections::HashMap; + use base64::Engine; #[test] fn test_oauth_server_creation() { @@ -74,4 +74,178 @@ mod tests { assert!(result.is_err()); assert!(result.unwrap_err().contains("unsupported_response_type")); } + + #[test] + fn test_jwks_endpoint() { + let config = sts::Config::from_env(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + let jwks_json = oauth_server.get_jwks(); + assert!(!jwks_json.is_empty()); + + // Parse the JSON to verify structure + let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON"); + assert!(jwks["keys"].is_array()); + assert!(jwks["keys"].as_array().unwrap().len() > 0); + + let key = &jwks["keys"][0]; + assert_eq!(key["kty"], "RSA"); + assert_eq!(key["use"], "sig"); + assert_eq!(key["alg"], "RS256"); + assert!(key["kid"].is_string()); + assert!(key["n"].is_string()); + assert!(key["e"].is_string()); + } + + #[test] + fn test_full_oauth_flow_with_rsa_tokens() { + let config = sts::Config::from_env(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + // Step 1: Authorization request + let mut auth_params = HashMap::new(); + auth_params.insert("client_id".to_string(), "test_client".to_string()); + auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert("response_type".to_string(), "code".to_string()); + auth_params.insert("state".to_string(), "test_state".to_string()); + + let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed"); + + // Extract the authorization code from redirect URL + let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); + let auth_code = redirect_url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| value.to_string()) + .expect("No authorization code in redirect"); + + // Step 2: Token request + let mut token_params = HashMap::new(); + token_params.insert("grant_type".to_string(), "authorization_code".to_string()); + token_params.insert("code".to_string(), auth_code); + token_params.insert("client_id".to_string(), "test_client".to_string()); + + let token_result = oauth_server.handle_token(&token_params).expect("Token request failed"); + + // Parse token response + let token_response: serde_json::Value = serde_json::from_str(&token_result) + .expect("Invalid token response JSON"); + + assert_eq!(token_response["token_type"], "Bearer"); + assert_eq!(token_response["expires_in"], 3600); + assert!(token_response["access_token"].is_string()); + + let access_token = token_response["access_token"].as_str().unwrap(); + + // Step 3: Verify the JWT token has RSA signature and key ID + let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); + assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256); + assert!(header.kid.is_some()); + assert!(!header.kid.as_ref().unwrap().is_empty()); + } + + #[test] + fn test_token_validation_with_jwks() { + let config = sts::Config::from_env(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + // Generate a token + let mut auth_params = HashMap::new(); + auth_params.insert("client_id".to_string(), "test_client".to_string()); + auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert("response_type".to_string(), "code".to_string()); + + let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed"); + let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); + let auth_code = redirect_url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| value.to_string()) + .expect("No authorization code in redirect"); + + let mut token_params = HashMap::new(); + token_params.insert("grant_type".to_string(), "authorization_code".to_string()); + token_params.insert("code".to_string(), auth_code); + token_params.insert("client_id".to_string(), "test_client".to_string()); + + let token_result = oauth_server.handle_token(&token_params).expect("Token request failed"); + let token_response: serde_json::Value = serde_json::from_str(&token_result) + .expect("Invalid token response JSON"); + let access_token = token_response["access_token"].as_str().unwrap(); + + // Get the JWKS + let jwks_json = oauth_server.get_jwks(); + let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON"); + + // Decode the token header to get the key ID + let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); + let kid = header.kid.as_ref().expect("No key ID in token"); + + // Find the matching key in JWKS + let matching_key = jwks["keys"] + .as_array() + .unwrap() + .iter() + .find(|key| key["kid"] == *kid) + .expect("Key ID not found in JWKS"); + + assert_eq!(matching_key["kty"], "RSA"); + assert_eq!(matching_key["alg"], "RS256"); + } + + #[test] + fn test_token_contains_proper_claims() { + let config = sts::Config::from_env(); + let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); + + // Generate a token through the full flow + let mut auth_params = HashMap::new(); + auth_params.insert("client_id".to_string(), "test_client".to_string()); + auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert("response_type".to_string(), "code".to_string()); + auth_params.insert("scope".to_string(), "openid profile".to_string()); + + let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed"); + let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); + let auth_code = redirect_url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| value.to_string()) + .expect("No authorization code in redirect"); + + let mut token_params = HashMap::new(); + token_params.insert("grant_type".to_string(), "authorization_code".to_string()); + token_params.insert("code".to_string(), auth_code); + token_params.insert("client_id".to_string(), "test_client".to_string()); + + let token_result = oauth_server.handle_token(&token_params).expect("Token request failed"); + let token_response: serde_json::Value = serde_json::from_str(&token_result) + .expect("Invalid token response JSON"); + let access_token = token_response["access_token"].as_str().unwrap(); + + // Decode the token without verification to check claims + let _token_data = jsonwebtoken::decode::<serde_json::Value>( + access_token, + &jsonwebtoken::DecodingKey::from_secret(b"dummy"), // We're not validating, just parsing + &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256) + ); + + // Since we can't validate with a dummy key, we'll just verify the structure + // by decoding the payload manually + let parts: Vec<&str> = access_token.split('.').collect(); + assert_eq!(parts.len(), 3); // header.payload.signature + + let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(parts[1]) + .expect("Failed to decode payload"); + let claims: serde_json::Value = serde_json::from_slice(&payload).expect("Invalid claims JSON"); + + assert!(claims["sub"].is_string()); + assert!(claims["iss"].is_string()); + assert!(claims["aud"].is_string()); + assert!(claims["exp"].is_number()); + assert!(claims["iat"].is_number()); + assert_eq!(claims["aud"], "test_client"); + assert_eq!(claims["scope"], "openid profile"); + } } |
