use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; use sts::container::ServiceContainer; use sts::http::Server; use sts::{Config, Database}; fn main() { let config = Config::from_env(); // Initialize database let database = Database::new(&config.database_path).expect("Failed to initialize database"); let database = Arc::new(Mutex::new(database)); // Initialize service container with dependency injection let container = ServiceContainer::new(config.clone(), database.clone()) .expect("Failed to create service container"); let container = Arc::new(container); let server = Server::new_with_container(config.clone(), container.clone()) .expect("Failed to create server"); // Start cleanup task in background let cleanup_container = container.clone(); let cleanup_config = config.clone(); thread::spawn(move || { loop { thread::sleep(Duration::from_secs( cleanup_config.cleanup_interval_hours as u64 * 3600, )); if let Err(e) = cleanup_container.cleanup_expired_data() { eprintln!("Cleanup task failed: {}", e); } } }); println!("Starting OAuth2 STS server..."); println!("Configuration:"); println!(" Bind Address: {}", config.bind_addr); println!(" Issuer URL: {}", config.issuer_url); println!(" Database: {}", config.database_path); println!( " Rate Limit: {} requests/minute", config.rate_limit_requests_per_minute ); println!(" Audit Logging: {}", config.enable_audit_logging); server.start(); } #[cfg(test)] mod tests { use base64::Engine; use std::collections::HashMap; fn setup_test_environment() -> sts::Config { let mut config = sts::Config::from_env(); config.database_path = ":memory:".to_string(); // Use in-memory database for tests config.bind_addr = "127.0.0.1:0".to_string(); config.issuer_url = format!("http://{}", config.bind_addr); config } #[test] fn test_oauth_server_creation() { let config = setup_test_environment(); let server = sts::http::Server::new(config); assert!(server.is_ok()); } #[test] fn test_authorization_code_generation() { let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let mut params = HashMap::new(); params.insert("client_id".to_string(), "test_client".to_string()); params.insert( "redirect_uri".to_string(), "http://localhost:3000/callback".to_string(), ); params.insert("response_type".to_string(), "code".to_string()); params.insert("state".to_string(), "test_state".to_string()); let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); assert!(result.is_ok()); let redirect_url = result.unwrap(); assert!(redirect_url.contains("code=")); assert!(redirect_url.contains("state=test_state")); } #[test] fn test_missing_client_id() { let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let mut params = HashMap::new(); params.insert( "redirect_uri".to_string(), "http://localhost:3000/callback".to_string(), ); params.insert("response_type".to_string(), "code".to_string()); let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("invalid_request")); } #[test] fn test_unsupported_response_type() { let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let mut params = HashMap::new(); params.insert("client_id".to_string(), "test_client".to_string()); params.insert( "redirect_uri".to_string(), "http://localhost:3000/callback".to_string(), ); params.insert("response_type".to_string(), "token".to_string()); let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("unsupported_response_type")); } #[test] fn test_jwks_endpoint() { let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let jwks_json = oauth_server.get_jwks(); assert!(!jwks_json.is_empty()); // Parse the JSON to verify structure let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON"); assert!(jwks["keys"].is_array()); assert!(jwks["keys"].as_array().unwrap().len() > 0); let key = &jwks["keys"][0]; assert_eq!(key["kty"], "RSA"); assert_eq!(key["use"], "sig"); assert_eq!(key["alg"], "RS256"); assert!(key["kid"].is_string()); assert!(key["n"].is_string()); assert!(key["e"].is_string()); } #[test] fn test_full_oauth_flow_with_rsa_tokens() { let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); // Step 1: Authorization request let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); auth_params.insert( "redirect_uri".to_string(), "http://localhost:3000/callback".to_string(), ); auth_params.insert("response_type".to_string(), "code".to_string()); auth_params.insert("state".to_string(), "test_state".to_string()); let auth_result = oauth_server .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) .expect("Authorization failed"); // Extract the authorization code from redirect URL let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() .find(|(key, _)| key == "code") .map(|(_, value)| value.to_string()) .expect("No authorization code in redirect"); // Step 2: Token request let mut token_params = HashMap::new(); token_params.insert("grant_type".to_string(), "authorization_code".to_string()); token_params.insert("code".to_string(), auth_code); token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "test_secret".to_string()); let token_result = oauth_server .handle_token(&token_params, None, Some("127.0.0.1".to_string())) .expect("Token request failed"); // Parse token response let token_response: serde_json::Value = serde_json::from_str(&token_result).expect("Invalid token response JSON"); assert_eq!(token_response["token_type"], "Bearer"); assert_eq!(token_response["expires_in"], 3600); assert!(token_response["access_token"].is_string()); let access_token = token_response["access_token"].as_str().unwrap(); // Step 3: Verify the JWT token has RSA signature and key ID let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256); assert!(header.kid.is_some()); assert!(!header.kid.as_ref().unwrap().is_empty()); } #[test] fn test_token_validation_with_jwks() { let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); // Generate a token let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); auth_params.insert( "redirect_uri".to_string(), "http://localhost:3000/callback".to_string(), ); auth_params.insert("response_type".to_string(), "code".to_string()); let auth_result = oauth_server .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() .find(|(key, _)| key == "code") .map(|(_, value)| value.to_string()) .expect("No authorization code in redirect"); let mut token_params = HashMap::new(); token_params.insert("grant_type".to_string(), "authorization_code".to_string()); token_params.insert("code".to_string(), auth_code); token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "test_secret".to_string()); let token_result = oauth_server .handle_token(&token_params, None, Some("127.0.0.1".to_string())) .expect("Token request failed"); let token_response: serde_json::Value = serde_json::from_str(&token_result).expect("Invalid token response JSON"); let access_token = token_response["access_token"].as_str().unwrap(); // Get the JWKS let jwks_json = oauth_server.get_jwks(); let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON"); // Decode the token header to get the key ID let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); let kid = header.kid.as_ref().expect("No key ID in token"); // Find the matching key in JWKS let matching_key = jwks["keys"] .as_array() .unwrap() .iter() .find(|key| key["kid"] == *kid) .expect("Key ID not found in JWKS"); assert_eq!(matching_key["kty"], "RSA"); assert_eq!(matching_key["alg"], "RS256"); } #[test] fn test_token_contains_proper_claims() { let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); // Generate a token through the full flow let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); auth_params.insert( "redirect_uri".to_string(), "http://localhost:3000/callback".to_string(), ); auth_params.insert("response_type".to_string(), "code".to_string()); auth_params.insert("scope".to_string(), "openid profile".to_string()); let auth_result = oauth_server .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() .find(|(key, _)| key == "code") .map(|(_, value)| value.to_string()) .expect("No authorization code in redirect"); let mut token_params = HashMap::new(); token_params.insert("grant_type".to_string(), "authorization_code".to_string()); token_params.insert("code".to_string(), auth_code); token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "test_secret".to_string()); let token_result = oauth_server .handle_token(&token_params, None, Some("127.0.0.1".to_string())) .expect("Token request failed"); let token_response: serde_json::Value = serde_json::from_str(&token_result).expect("Invalid token response JSON"); let access_token = token_response["access_token"].as_str().unwrap(); // Decode the token without verification to check claims let _token_data = jsonwebtoken::decode::( access_token, &jsonwebtoken::DecodingKey::from_secret(b"dummy"), // We're not validating, just parsing &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256), ); // Since we can't validate with a dummy key, we'll just verify the structure // by decoding the payload manually let parts: Vec<&str> = access_token.split('.').collect(); assert_eq!(parts.len(), 3); // header.payload.signature let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD .decode(parts[1]) .expect("Failed to decode payload"); let claims: serde_json::Value = serde_json::from_slice(&payload).expect("Invalid claims JSON"); assert!(claims["sub"].is_string()); assert!(claims["iss"].is_string()); assert!(claims["aud"].is_string()); assert!(claims["exp"].is_number()); assert!(claims["iat"].is_number()); assert_eq!(claims["aud"], "test_client"); assert_eq!(claims["scope"], "openid profile"); } #[test] fn test_invalid_client_id_authorization() { let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let mut params = HashMap::new(); params.insert("client_id".to_string(), "invalid_client".to_string()); params.insert( "redirect_uri".to_string(), "http://localhost:3000/callback".to_string(), ); params.insert("response_type".to_string(), "code".to_string()); let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("invalid_client")); } #[test] fn test_invalid_redirect_uri() { let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let mut params = HashMap::new(); params.insert("client_id".to_string(), "test_client".to_string()); params.insert( "redirect_uri".to_string(), "https://evil.com/callback".to_string(), ); params.insert("response_type".to_string(), "code".to_string()); let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("Invalid redirect_uri")); } #[test] fn test_invalid_scope() { let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let mut params = HashMap::new(); params.insert("client_id".to_string(), "test_client".to_string()); params.insert( "redirect_uri".to_string(), "http://localhost:3000/callback".to_string(), ); params.insert("response_type".to_string(), "code".to_string()); params.insert("scope".to_string(), "invalid_scope".to_string()); let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("invalid_scope")); } #[test] fn test_invalid_client_secret() { let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); // First get an authorization code let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); auth_params.insert( "redirect_uri".to_string(), "http://localhost:3000/callback".to_string(), ); auth_params.insert("response_type".to_string(), "code".to_string()); let auth_result = oauth_server .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() .find(|(key, _)| key == "code") .map(|(_, value)| value.to_string()) .expect("No authorization code in redirect"); // Try to exchange with wrong client secret let mut token_params = HashMap::new(); token_params.insert("grant_type".to_string(), "authorization_code".to_string()); token_params.insert("code".to_string(), auth_code); token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "wrong_secret".to_string()); let result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("invalid_client")); } #[test] fn test_missing_client_secret() { let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); // First get an authorization code let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); auth_params.insert( "redirect_uri".to_string(), "http://localhost:3000/callback".to_string(), ); auth_params.insert("response_type".to_string(), "code".to_string()); let auth_result = oauth_server .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() .find(|(key, _)| key == "code") .map(|(_, value)| value.to_string()) .expect("No authorization code in redirect"); // Try to exchange without client secret let mut token_params = HashMap::new(); token_params.insert("grant_type".to_string(), "authorization_code".to_string()); token_params.insert("code".to_string(), auth_code); token_params.insert("client_id".to_string(), "test_client".to_string()); // Missing client_secret let result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("Missing client_secret")); } #[test] fn test_http_basic_auth() { let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); // First get an authorization code let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); auth_params.insert( "redirect_uri".to_string(), "http://localhost:3000/callback".to_string(), ); auth_params.insert("response_type".to_string(), "code".to_string()); let auth_result = oauth_server .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() .find(|(key, _)| key == "code") .map(|(_, value)| value.to_string()) .expect("No authorization code in redirect"); // Use HTTP Basic Auth instead of form parameters let mut token_params = HashMap::new(); token_params.insert("grant_type".to_string(), "authorization_code".to_string()); token_params.insert("code".to_string(), auth_code); // No client_id/client_secret in form // test_client:test_secret encoded in base64 let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; let result = oauth_server.handle_token( &token_params, Some(auth_header), Some("127.0.0.1".to_string()), ); assert!(result.is_ok()); } }