summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-09 17:27:57 -0600
committermo khan <mo@mokhan.ca>2025-06-09 17:27:57 -0600
commit50731865b1a22ab250e988aed2ea2bb8a18f9338 (patch)
tree2155fb783438b9c2d42bbe7cf2821f46d75d3a06
parent72c2297eda4c18f75e7d8587773b36f3ac98b309 (diff)
test: add missing tests
-rw-r--r--src/keys.rs115
-rw-r--r--src/main.rs176
2 files changed, 290 insertions, 1 deletions
diff --git a/src/keys.rs b/src/keys.rs
index 6c25681..10c9697 100644
--- a/src/keys.rs
+++ b/src/keys.rs
@@ -21,6 +21,7 @@ pub struct KeyPair {
#[derive(Debug, Serialize)]
pub struct JwkKey {
pub kty: String,
+ #[serde(rename = "use")]
pub use_: String,
pub kid: String,
pub alg: String,
@@ -152,3 +153,117 @@ impl KeyManager {
});
}
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_key_manager_creation() {
+ let manager = KeyManager::new().expect("Failed to create key manager");
+ assert!(manager.get_current_key().is_some());
+ assert_eq!(manager.keys.len(), 1);
+ }
+
+ #[test]
+ fn test_key_generation() {
+ let mut manager = KeyManager::new().expect("Failed to create key manager");
+ let initial_key_count = manager.keys.len();
+
+ let new_kid = manager.generate_new_key().expect("Failed to generate new key");
+ assert_eq!(manager.keys.len(), initial_key_count + 1);
+ assert_eq!(manager.current_key_id, Some(new_kid.clone()));
+ assert!(manager.get_key(&new_kid).is_some());
+ }
+
+ #[test]
+ fn test_jwks_generation() {
+ let manager = KeyManager::new().expect("Failed to create key manager");
+ let jwks = manager.get_jwks().expect("Failed to get JWKS");
+
+ assert_eq!(jwks.keys.len(), 1);
+ let key = &jwks.keys[0];
+ assert_eq!(key.kty, "RSA");
+ assert_eq!(key.use_, "sig");
+ assert_eq!(key.alg, "RS256");
+ assert!(!key.n.is_empty());
+ assert!(!key.e.is_empty());
+ assert!(!key.kid.is_empty());
+ }
+
+ #[test]
+ fn test_key_rotation() {
+ let mut manager = KeyManager::new().expect("Failed to create key manager");
+ let original_kid = manager.current_key_id.clone().unwrap();
+
+ manager.rotate_keys().expect("Failed to rotate keys");
+
+ let new_kid = manager.current_key_id.clone().unwrap();
+ assert_ne!(original_kid, new_kid);
+ assert_eq!(manager.keys.len(), 2); // Should have both old and new keys
+ assert!(manager.get_key(&original_kid).is_some());
+ assert!(manager.get_key(&new_kid).is_some());
+ }
+
+ #[test]
+ fn test_should_rotate_new_key() {
+ let manager = KeyManager::new().expect("Failed to create key manager");
+ // New key should not need rotation
+ assert!(!manager.should_rotate());
+ }
+
+ #[test]
+ fn test_should_rotate_old_key() {
+ let mut manager = KeyManager::new().expect("Failed to create key manager");
+
+ // Manually modify the current key's creation time to be old
+ if let Some(current_kid) = manager.current_key_id.clone() {
+ if let Some(key_pair) = manager.keys.get_mut(&current_kid) {
+ let old_time = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap()
+ .as_secs() - 86401; // 1 day + 1 second ago
+
+ // We need to recreate the key pair with the old timestamp
+ let mut old_key_pair = key_pair.clone();
+ old_key_pair.created_at = old_time;
+ manager.keys.insert(current_kid, old_key_pair);
+ }
+ }
+
+ // Should need rotation since key is older than rotation interval
+ assert!(manager.should_rotate());
+ }
+
+ #[test]
+ fn test_cleanup_old_keys() {
+ let mut manager = KeyManager::new().expect("Failed to create key manager");
+ let original_kid = manager.current_key_id.clone().unwrap();
+
+ // Generate a new key (so we have 2 keys)
+ manager.rotate_keys().expect("Failed to rotate keys");
+ assert_eq!(manager.keys.len(), 2);
+
+ // Cleanup with max_age 0 should remove old keys but keep current
+ manager.cleanup_old_keys(0);
+ assert_eq!(manager.keys.len(), 1);
+ assert!(manager.get_key(&original_kid).is_none());
+ assert!(manager.get_current_key().is_some());
+ }
+
+ #[test]
+ fn test_multiple_key_jwks() {
+ let mut manager = KeyManager::new().expect("Failed to create key manager");
+ manager.rotate_keys().expect("Failed to rotate keys");
+ manager.rotate_keys().expect("Failed to rotate keys");
+
+ let jwks = manager.get_jwks().expect("Failed to get JWKS");
+ assert_eq!(jwks.keys.len(), 3); // Should have 3 keys
+
+ // All keys should have unique key IDs
+ let mut kids: Vec<String> = jwks.keys.iter().map(|k| k.kid.clone()).collect();
+ kids.sort();
+ kids.dedup();
+ assert_eq!(kids.len(), 3); // All should be unique
+ }
+}
diff --git a/src/main.rs b/src/main.rs
index 25d56ea..9b6b0ea 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -12,8 +12,8 @@ fn main() {
#[cfg(test)]
mod tests {
- use super::*;
use std::collections::HashMap;
+ use base64::Engine;
#[test]
fn test_oauth_server_creation() {
@@ -74,4 +74,178 @@ mod tests {
assert!(result.is_err());
assert!(result.unwrap_err().contains("unsupported_response_type"));
}
+
+ #[test]
+ fn test_jwks_endpoint() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ let jwks_json = oauth_server.get_jwks();
+ assert!(!jwks_json.is_empty());
+
+ // Parse the JSON to verify structure
+ let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON");
+ assert!(jwks["keys"].is_array());
+ assert!(jwks["keys"].as_array().unwrap().len() > 0);
+
+ let key = &jwks["keys"][0];
+ assert_eq!(key["kty"], "RSA");
+ assert_eq!(key["use"], "sig");
+ assert_eq!(key["alg"], "RS256");
+ assert!(key["kid"].is_string());
+ assert!(key["n"].is_string());
+ assert!(key["e"].is_string());
+ }
+
+ #[test]
+ fn test_full_oauth_flow_with_rsa_tokens() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ // Step 1: Authorization request
+ let mut auth_params = HashMap::new();
+ auth_params.insert("client_id".to_string(), "test_client".to_string());
+ auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert("response_type".to_string(), "code".to_string());
+ auth_params.insert("state".to_string(), "test_state".to_string());
+
+ let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+
+ // Extract the authorization code from redirect URL
+ let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
+ let auth_code = redirect_url
+ .query_pairs()
+ .find(|(key, _)| key == "code")
+ .map(|(_, value)| value.to_string())
+ .expect("No authorization code in redirect");
+
+ // Step 2: Token request
+ let mut token_params = HashMap::new();
+ token_params.insert("grant_type".to_string(), "authorization_code".to_string());
+ token_params.insert("code".to_string(), auth_code);
+ token_params.insert("client_id".to_string(), "test_client".to_string());
+
+ let token_result = oauth_server.handle_token(&token_params).expect("Token request failed");
+
+ // Parse token response
+ let token_response: serde_json::Value = serde_json::from_str(&token_result)
+ .expect("Invalid token response JSON");
+
+ assert_eq!(token_response["token_type"], "Bearer");
+ assert_eq!(token_response["expires_in"], 3600);
+ assert!(token_response["access_token"].is_string());
+
+ let access_token = token_response["access_token"].as_str().unwrap();
+
+ // Step 3: Verify the JWT token has RSA signature and key ID
+ let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
+ assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
+ assert!(header.kid.is_some());
+ assert!(!header.kid.as_ref().unwrap().is_empty());
+ }
+
+ #[test]
+ fn test_token_validation_with_jwks() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ // Generate a token
+ let mut auth_params = HashMap::new();
+ auth_params.insert("client_id".to_string(), "test_client".to_string());
+ auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert("response_type".to_string(), "code".to_string());
+
+ let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
+ let auth_code = redirect_url
+ .query_pairs()
+ .find(|(key, _)| key == "code")
+ .map(|(_, value)| value.to_string())
+ .expect("No authorization code in redirect");
+
+ let mut token_params = HashMap::new();
+ token_params.insert("grant_type".to_string(), "authorization_code".to_string());
+ token_params.insert("code".to_string(), auth_code);
+ token_params.insert("client_id".to_string(), "test_client".to_string());
+
+ let token_result = oauth_server.handle_token(&token_params).expect("Token request failed");
+ let token_response: serde_json::Value = serde_json::from_str(&token_result)
+ .expect("Invalid token response JSON");
+ let access_token = token_response["access_token"].as_str().unwrap();
+
+ // Get the JWKS
+ let jwks_json = oauth_server.get_jwks();
+ let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON");
+
+ // Decode the token header to get the key ID
+ let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
+ let kid = header.kid.as_ref().expect("No key ID in token");
+
+ // Find the matching key in JWKS
+ let matching_key = jwks["keys"]
+ .as_array()
+ .unwrap()
+ .iter()
+ .find(|key| key["kid"] == *kid)
+ .expect("Key ID not found in JWKS");
+
+ assert_eq!(matching_key["kty"], "RSA");
+ assert_eq!(matching_key["alg"], "RS256");
+ }
+
+ #[test]
+ fn test_token_contains_proper_claims() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ // Generate a token through the full flow
+ let mut auth_params = HashMap::new();
+ auth_params.insert("client_id".to_string(), "test_client".to_string());
+ auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert("response_type".to_string(), "code".to_string());
+ auth_params.insert("scope".to_string(), "openid profile".to_string());
+
+ let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
+ let auth_code = redirect_url
+ .query_pairs()
+ .find(|(key, _)| key == "code")
+ .map(|(_, value)| value.to_string())
+ .expect("No authorization code in redirect");
+
+ let mut token_params = HashMap::new();
+ token_params.insert("grant_type".to_string(), "authorization_code".to_string());
+ token_params.insert("code".to_string(), auth_code);
+ token_params.insert("client_id".to_string(), "test_client".to_string());
+
+ let token_result = oauth_server.handle_token(&token_params).expect("Token request failed");
+ let token_response: serde_json::Value = serde_json::from_str(&token_result)
+ .expect("Invalid token response JSON");
+ let access_token = token_response["access_token"].as_str().unwrap();
+
+ // Decode the token without verification to check claims
+ let _token_data = jsonwebtoken::decode::<serde_json::Value>(
+ access_token,
+ &jsonwebtoken::DecodingKey::from_secret(b"dummy"), // We're not validating, just parsing
+ &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256)
+ );
+
+ // Since we can't validate with a dummy key, we'll just verify the structure
+ // by decoding the payload manually
+ let parts: Vec<&str> = access_token.split('.').collect();
+ assert_eq!(parts.len(), 3); // header.payload.signature
+
+ let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
+ .decode(parts[1])
+ .expect("Failed to decode payload");
+ let claims: serde_json::Value = serde_json::from_slice(&payload).expect("Invalid claims JSON");
+
+ assert!(claims["sub"].is_string());
+ assert!(claims["iss"].is_string());
+ assert!(claims["aud"].is_string());
+ assert!(claims["exp"].is_number());
+ assert!(claims["iat"].is_number());
+ assert_eq!(claims["aud"], "test_client");
+ assert_eq!(claims["scope"], "openid profile");
+ }
}