diff options
Diffstat (limited to 'src/main.rs')
| -rw-r--r-- | src/main.rs | 134 |
1 files changed, 100 insertions, 34 deletions
diff --git a/src/main.rs b/src/main.rs index ac47a5e..4873a1d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,22 +1,36 @@ +use std::sync::{Arc, Mutex}; use std::thread; use std::time::Duration; -use sts::Config; +use sts::container::ServiceContainer; use sts::http::Server; +use sts::{Config, Database}; fn main() { let config = Config::from_env(); - let server = Server::new(config.clone()).expect("Failed to create server"); + + // Initialize database + let database = Database::new(&config.database_path).expect("Failed to initialize database"); + let database = Arc::new(Mutex::new(database)); + + // Initialize service container with dependency injection + let container = ServiceContainer::new(config.clone(), database.clone()) + .expect("Failed to create service container"); + let container = Arc::new(container); + + let server = Server::new_with_container(config.clone(), container.clone()) + .expect("Failed to create server"); // Start cleanup task in background + let cleanup_container = container.clone(); let cleanup_config = config.clone(); thread::spawn(move || { loop { thread::sleep(Duration::from_secs( cleanup_config.cleanup_interval_hours as u64 * 3600, )); - // Note: In the current implementation, we don't have direct access to the OAuth server - // from here to call cleanup_expired_data(). In a production implementation, - // you'd want to structure this differently or use a background job queue. + if let Err(e) = cleanup_container.cleanup_expired_data() { + eprintln!("Cleanup task failed: {}", e); + } } }); @@ -139,11 +153,16 @@ mod tests { // Step 1: Authorization request let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); auth_params.insert("response_type".to_string(), "code".to_string()); auth_params.insert("state".to_string(), "test_state".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); // Extract the authorization code from redirect URL let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); @@ -160,11 +179,13 @@ mod tests { token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "test_secret".to_string()); - let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed"); + let token_result = oauth_server + .handle_token(&token_params, None, Some("127.0.0.1".to_string())) + .expect("Token request failed"); // Parse token response - let token_response: serde_json::Value = serde_json::from_str(&token_result) - .expect("Invalid token response JSON"); + let token_response: serde_json::Value = + serde_json::from_str(&token_result).expect("Invalid token response JSON"); assert_eq!(token_response["token_type"], "Bearer"); assert_eq!(token_response["expires_in"], 3600); @@ -173,7 +194,8 @@ mod tests { let access_token = token_response["access_token"].as_str().unwrap(); // Step 3: Verify the JWT token has RSA signature and key ID - let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); + let header = + jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256); assert!(header.kid.is_some()); assert!(!header.kid.as_ref().unwrap().is_empty()); @@ -187,10 +209,15 @@ mod tests { // Generate a token let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); auth_params.insert("response_type".to_string(), "code".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -204,9 +231,11 @@ mod tests { token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "test_secret".to_string()); - let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed"); - let token_response: serde_json::Value = serde_json::from_str(&token_result) - .expect("Invalid token response JSON"); + let token_result = oauth_server + .handle_token(&token_params, None, Some("127.0.0.1".to_string())) + .expect("Token request failed"); + let token_response: serde_json::Value = + serde_json::from_str(&token_result).expect("Invalid token response JSON"); let access_token = token_response["access_token"].as_str().unwrap(); // Get the JWKS @@ -214,7 +243,8 @@ mod tests { let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON"); // Decode the token header to get the key ID - let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); + let header = + jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); let kid = header.kid.as_ref().expect("No key ID in token"); // Find the matching key in JWKS @@ -237,11 +267,16 @@ mod tests { // Generate a token through the full flow let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); auth_params.insert("response_type".to_string(), "code".to_string()); auth_params.insert("scope".to_string(), "openid profile".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -255,16 +290,18 @@ mod tests { token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "test_secret".to_string()); - let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed"); - let token_response: serde_json::Value = serde_json::from_str(&token_result) - .expect("Invalid token response JSON"); + let token_result = oauth_server + .handle_token(&token_params, None, Some("127.0.0.1".to_string())) + .expect("Token request failed"); + let token_response: serde_json::Value = + serde_json::from_str(&token_result).expect("Invalid token response JSON"); let access_token = token_response["access_token"].as_str().unwrap(); // Decode the token without verification to check claims let _token_data = jsonwebtoken::decode::<serde_json::Value>( access_token, &jsonwebtoken::DecodingKey::from_secret(b"dummy"), // We're not validating, just parsing - &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256) + &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256), ); // Since we can't validate with a dummy key, we'll just verify the structure @@ -275,7 +312,8 @@ mod tests { let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD .decode(parts[1]) .expect("Failed to decode payload"); - let claims: serde_json::Value = serde_json::from_slice(&payload).expect("Invalid claims JSON"); + let claims: serde_json::Value = + serde_json::from_slice(&payload).expect("Invalid claims JSON"); assert!(claims["sub"].is_string()); assert!(claims["iss"].is_string()); @@ -293,7 +331,10 @@ mod tests { let mut params = HashMap::new(); params.insert("client_id".to_string(), "invalid_client".to_string()); - params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); params.insert("response_type".to_string(), "code".to_string()); let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); @@ -308,7 +349,10 @@ mod tests { let mut params = HashMap::new(); params.insert("client_id".to_string(), "test_client".to_string()); - params.insert("redirect_uri".to_string(), "https://evil.com/callback".to_string()); + params.insert( + "redirect_uri".to_string(), + "https://evil.com/callback".to_string(), + ); params.insert("response_type".to_string(), "code".to_string()); let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); @@ -323,7 +367,10 @@ mod tests { let mut params = HashMap::new(); params.insert("client_id".to_string(), "test_client".to_string()); - params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); params.insert("response_type".to_string(), "code".to_string()); params.insert("scope".to_string(), "invalid_scope".to_string()); @@ -340,10 +387,15 @@ mod tests { // First get an authorization code let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); auth_params.insert("response_type".to_string(), "code".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -371,10 +423,15 @@ mod tests { // First get an authorization code let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); auth_params.insert("response_type".to_string(), "code".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -402,10 +459,15 @@ mod tests { // First get an authorization code let mut auth_params = HashMap::new(); auth_params.insert("client_id".to_string(), "test_client".to_string()); - auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); + auth_params.insert( + "redirect_uri".to_string(), + "http://localhost:3000/callback".to_string(), + ); auth_params.insert("response_type".to_string(), "code".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); + let auth_result = oauth_server + .handle_authorize(&auth_params, Some("127.0.0.1".to_string())) + .expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -422,7 +484,11 @@ mod tests { // test_client:test_secret encoded in base64 let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; - let result = oauth_server.handle_token(&token_params, Some(auth_header), Some("127.0.0.1".to_string())); + let result = oauth_server.handle_token( + &token_params, + Some(auth_header), + Some("127.0.0.1".to_string()), + ); assert!(result.is_ok()); } } |
