summaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-11 17:13:57 -0600
committermo khan <mo@mokhan.ca>2025-06-11 17:13:57 -0600
commit9b8a098bfcfdd73bfdfcff0cb397ef2694a90367 (patch)
treea0e89e827efdc8e3105d2f3033bbde987d65a28f /src/main.rs
parent5ffc9b007ccbd8a4510b58de72aaee53291d7973 (diff)
refactor: extract sts bin
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs494
1 files changed, 0 insertions, 494 deletions
diff --git a/src/main.rs b/src/main.rs
deleted file mode 100644
index 4873a1d..0000000
--- a/src/main.rs
+++ /dev/null
@@ -1,494 +0,0 @@
-use std::sync::{Arc, Mutex};
-use std::thread;
-use std::time::Duration;
-use sts::container::ServiceContainer;
-use sts::http::Server;
-use sts::{Config, Database};
-
-fn main() {
- let config = Config::from_env();
-
- // Initialize database
- let database = Database::new(&config.database_path).expect("Failed to initialize database");
- let database = Arc::new(Mutex::new(database));
-
- // Initialize service container with dependency injection
- let container = ServiceContainer::new(config.clone(), database.clone())
- .expect("Failed to create service container");
- let container = Arc::new(container);
-
- let server = Server::new_with_container(config.clone(), container.clone())
- .expect("Failed to create server");
-
- // Start cleanup task in background
- let cleanup_container = container.clone();
- let cleanup_config = config.clone();
- thread::spawn(move || {
- loop {
- thread::sleep(Duration::from_secs(
- cleanup_config.cleanup_interval_hours as u64 * 3600,
- ));
- if let Err(e) = cleanup_container.cleanup_expired_data() {
- eprintln!("Cleanup task failed: {}", e);
- }
- }
- });
-
- println!("Starting OAuth2 STS server...");
- println!("Configuration:");
- println!(" Bind Address: {}", config.bind_addr);
- println!(" Issuer URL: {}", config.issuer_url);
- println!(" Database: {}", config.database_path);
- println!(
- " Rate Limit: {} requests/minute",
- config.rate_limit_requests_per_minute
- );
- println!(" Audit Logging: {}", config.enable_audit_logging);
-
- server.start();
-}
-
-#[cfg(test)]
-mod tests {
- use base64::Engine;
- use std::collections::HashMap;
-
- fn setup_test_environment() -> sts::Config {
- let mut config = sts::Config::from_env();
- config.database_path = ":memory:".to_string(); // Use in-memory database for tests
- config.bind_addr = "127.0.0.1:0".to_string();
- config.issuer_url = format!("http://{}", config.bind_addr);
-
- config
- }
-
- #[test]
- fn test_oauth_server_creation() {
- let config = setup_test_environment();
- let server = sts::http::Server::new(config);
- assert!(server.is_ok());
- }
-
- #[test]
- fn test_authorization_code_generation() {
- let config = setup_test_environment();
- let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
- let mut params = HashMap::new();
- params.insert("client_id".to_string(), "test_client".to_string());
- params.insert(
- "redirect_uri".to_string(),
- "http://localhost:3000/callback".to_string(),
- );
- params.insert("response_type".to_string(), "code".to_string());
- params.insert("state".to_string(), "test_state".to_string());
-
- let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
- assert!(result.is_ok());
-
- let redirect_url = result.unwrap();
- assert!(redirect_url.contains("code="));
- assert!(redirect_url.contains("state=test_state"));
- }
-
- #[test]
- fn test_missing_client_id() {
- let config = setup_test_environment();
- let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
- let mut params = HashMap::new();
- params.insert(
- "redirect_uri".to_string(),
- "http://localhost:3000/callback".to_string(),
- );
- params.insert("response_type".to_string(), "code".to_string());
-
- let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
- assert!(result.is_err());
- assert!(result.unwrap_err().contains("invalid_request"));
- }
-
- #[test]
- fn test_unsupported_response_type() {
- let config = setup_test_environment();
- let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
- let mut params = HashMap::new();
- params.insert("client_id".to_string(), "test_client".to_string());
- params.insert(
- "redirect_uri".to_string(),
- "http://localhost:3000/callback".to_string(),
- );
- params.insert("response_type".to_string(), "token".to_string());
-
- let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
- assert!(result.is_err());
- assert!(result.unwrap_err().contains("unsupported_response_type"));
- }
-
- #[test]
- fn test_jwks_endpoint() {
- let config = setup_test_environment();
- let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
-
- let jwks_json = oauth_server.get_jwks();
- assert!(!jwks_json.is_empty());
-
- // Parse the JSON to verify structure
- let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON");
- assert!(jwks["keys"].is_array());
- assert!(jwks["keys"].as_array().unwrap().len() > 0);
-
- let key = &jwks["keys"][0];
- assert_eq!(key["kty"], "RSA");
- assert_eq!(key["use"], "sig");
- assert_eq!(key["alg"], "RS256");
- assert!(key["kid"].is_string());
- assert!(key["n"].is_string());
- assert!(key["e"].is_string());
- }
-
- #[test]
- fn test_full_oauth_flow_with_rsa_tokens() {
- let config = setup_test_environment();
- let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
-
- // Step 1: Authorization request
- let mut auth_params = HashMap::new();
- auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert(
- "redirect_uri".to_string(),
- "http://localhost:3000/callback".to_string(),
- );
- auth_params.insert("response_type".to_string(), "code".to_string());
- auth_params.insert("state".to_string(), "test_state".to_string());
-
- let auth_result = oauth_server
- .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
- .expect("Authorization failed");
-
- // Extract the authorization code from redirect URL
- let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
- let auth_code = redirect_url
- .query_pairs()
- .find(|(key, _)| key == "code")
- .map(|(_, value)| value.to_string())
- .expect("No authorization code in redirect");
-
- // Step 2: Token request
- let mut token_params = HashMap::new();
- token_params.insert("grant_type".to_string(), "authorization_code".to_string());
- token_params.insert("code".to_string(), auth_code);
- token_params.insert("client_id".to_string(), "test_client".to_string());
- token_params.insert("client_secret".to_string(), "test_secret".to_string());
-
- let token_result = oauth_server
- .handle_token(&token_params, None, Some("127.0.0.1".to_string()))
- .expect("Token request failed");
-
- // Parse token response
- let token_response: serde_json::Value =
- serde_json::from_str(&token_result).expect("Invalid token response JSON");
-
- assert_eq!(token_response["token_type"], "Bearer");
- assert_eq!(token_response["expires_in"], 3600);
- assert!(token_response["access_token"].is_string());
-
- let access_token = token_response["access_token"].as_str().unwrap();
-
- // Step 3: Verify the JWT token has RSA signature and key ID
- let header =
- jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
- assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
- assert!(header.kid.is_some());
- assert!(!header.kid.as_ref().unwrap().is_empty());
- }
-
- #[test]
- fn test_token_validation_with_jwks() {
- let config = setup_test_environment();
- let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
-
- // Generate a token
- let mut auth_params = HashMap::new();
- auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert(
- "redirect_uri".to_string(),
- "http://localhost:3000/callback".to_string(),
- );
- auth_params.insert("response_type".to_string(), "code".to_string());
-
- let auth_result = oauth_server
- .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
- .expect("Authorization failed");
- let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
- let auth_code = redirect_url
- .query_pairs()
- .find(|(key, _)| key == "code")
- .map(|(_, value)| value.to_string())
- .expect("No authorization code in redirect");
-
- let mut token_params = HashMap::new();
- token_params.insert("grant_type".to_string(), "authorization_code".to_string());
- token_params.insert("code".to_string(), auth_code);
- token_params.insert("client_id".to_string(), "test_client".to_string());
- token_params.insert("client_secret".to_string(), "test_secret".to_string());
-
- let token_result = oauth_server
- .handle_token(&token_params, None, Some("127.0.0.1".to_string()))
- .expect("Token request failed");
- let token_response: serde_json::Value =
- serde_json::from_str(&token_result).expect("Invalid token response JSON");
- let access_token = token_response["access_token"].as_str().unwrap();
-
- // Get the JWKS
- let jwks_json = oauth_server.get_jwks();
- let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON");
-
- // Decode the token header to get the key ID
- let header =
- jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
- let kid = header.kid.as_ref().expect("No key ID in token");
-
- // Find the matching key in JWKS
- let matching_key = jwks["keys"]
- .as_array()
- .unwrap()
- .iter()
- .find(|key| key["kid"] == *kid)
- .expect("Key ID not found in JWKS");
-
- assert_eq!(matching_key["kty"], "RSA");
- assert_eq!(matching_key["alg"], "RS256");
- }
-
- #[test]
- fn test_token_contains_proper_claims() {
- let config = setup_test_environment();
- let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
-
- // Generate a token through the full flow
- let mut auth_params = HashMap::new();
- auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert(
- "redirect_uri".to_string(),
- "http://localhost:3000/callback".to_string(),
- );
- auth_params.insert("response_type".to_string(), "code".to_string());
- auth_params.insert("scope".to_string(), "openid profile".to_string());
-
- let auth_result = oauth_server
- .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
- .expect("Authorization failed");
- let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
- let auth_code = redirect_url
- .query_pairs()
- .find(|(key, _)| key == "code")
- .map(|(_, value)| value.to_string())
- .expect("No authorization code in redirect");
-
- let mut token_params = HashMap::new();
- token_params.insert("grant_type".to_string(), "authorization_code".to_string());
- token_params.insert("code".to_string(), auth_code);
- token_params.insert("client_id".to_string(), "test_client".to_string());
- token_params.insert("client_secret".to_string(), "test_secret".to_string());
-
- let token_result = oauth_server
- .handle_token(&token_params, None, Some("127.0.0.1".to_string()))
- .expect("Token request failed");
- let token_response: serde_json::Value =
- serde_json::from_str(&token_result).expect("Invalid token response JSON");
- let access_token = token_response["access_token"].as_str().unwrap();
-
- // Decode the token without verification to check claims
- let _token_data = jsonwebtoken::decode::<serde_json::Value>(
- access_token,
- &jsonwebtoken::DecodingKey::from_secret(b"dummy"), // We're not validating, just parsing
- &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256),
- );
-
- // Since we can't validate with a dummy key, we'll just verify the structure
- // by decoding the payload manually
- let parts: Vec<&str> = access_token.split('.').collect();
- assert_eq!(parts.len(), 3); // header.payload.signature
-
- let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
- .decode(parts[1])
- .expect("Failed to decode payload");
- let claims: serde_json::Value =
- serde_json::from_slice(&payload).expect("Invalid claims JSON");
-
- assert!(claims["sub"].is_string());
- assert!(claims["iss"].is_string());
- assert!(claims["aud"].is_string());
- assert!(claims["exp"].is_number());
- assert!(claims["iat"].is_number());
- assert_eq!(claims["aud"], "test_client");
- assert_eq!(claims["scope"], "openid profile");
- }
-
- #[test]
- fn test_invalid_client_id_authorization() {
- let config = setup_test_environment();
- let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
-
- let mut params = HashMap::new();
- params.insert("client_id".to_string(), "invalid_client".to_string());
- params.insert(
- "redirect_uri".to_string(),
- "http://localhost:3000/callback".to_string(),
- );
- params.insert("response_type".to_string(), "code".to_string());
-
- let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
- assert!(result.is_err());
- assert!(result.unwrap_err().contains("invalid_client"));
- }
-
- #[test]
- fn test_invalid_redirect_uri() {
- let config = setup_test_environment();
- let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
-
- let mut params = HashMap::new();
- params.insert("client_id".to_string(), "test_client".to_string());
- params.insert(
- "redirect_uri".to_string(),
- "https://evil.com/callback".to_string(),
- );
- params.insert("response_type".to_string(), "code".to_string());
-
- let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
- assert!(result.is_err());
- assert!(result.unwrap_err().contains("Invalid redirect_uri"));
- }
-
- #[test]
- fn test_invalid_scope() {
- let config = setup_test_environment();
- let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
-
- let mut params = HashMap::new();
- params.insert("client_id".to_string(), "test_client".to_string());
- params.insert(
- "redirect_uri".to_string(),
- "http://localhost:3000/callback".to_string(),
- );
- params.insert("response_type".to_string(), "code".to_string());
- params.insert("scope".to_string(), "invalid_scope".to_string());
-
- let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
- assert!(result.is_err());
- assert!(result.unwrap_err().contains("invalid_scope"));
- }
-
- #[test]
- fn test_invalid_client_secret() {
- let config = setup_test_environment();
- let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
-
- // First get an authorization code
- let mut auth_params = HashMap::new();
- auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert(
- "redirect_uri".to_string(),
- "http://localhost:3000/callback".to_string(),
- );
- auth_params.insert("response_type".to_string(), "code".to_string());
-
- let auth_result = oauth_server
- .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
- .expect("Authorization failed");
- let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
- let auth_code = redirect_url
- .query_pairs()
- .find(|(key, _)| key == "code")
- .map(|(_, value)| value.to_string())
- .expect("No authorization code in redirect");
-
- // Try to exchange with wrong client secret
- let mut token_params = HashMap::new();
- token_params.insert("grant_type".to_string(), "authorization_code".to_string());
- token_params.insert("code".to_string(), auth_code);
- token_params.insert("client_id".to_string(), "test_client".to_string());
- token_params.insert("client_secret".to_string(), "wrong_secret".to_string());
-
- let result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string()));
- assert!(result.is_err());
- assert!(result.unwrap_err().contains("invalid_client"));
- }
-
- #[test]
- fn test_missing_client_secret() {
- let config = setup_test_environment();
- let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
-
- // First get an authorization code
- let mut auth_params = HashMap::new();
- auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert(
- "redirect_uri".to_string(),
- "http://localhost:3000/callback".to_string(),
- );
- auth_params.insert("response_type".to_string(), "code".to_string());
-
- let auth_result = oauth_server
- .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
- .expect("Authorization failed");
- let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
- let auth_code = redirect_url
- .query_pairs()
- .find(|(key, _)| key == "code")
- .map(|(_, value)| value.to_string())
- .expect("No authorization code in redirect");
-
- // Try to exchange without client secret
- let mut token_params = HashMap::new();
- token_params.insert("grant_type".to_string(), "authorization_code".to_string());
- token_params.insert("code".to_string(), auth_code);
- token_params.insert("client_id".to_string(), "test_client".to_string());
- // Missing client_secret
-
- let result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string()));
- assert!(result.is_err());
- assert!(result.unwrap_err().contains("Missing client_secret"));
- }
-
- #[test]
- fn test_http_basic_auth() {
- let config = setup_test_environment();
- let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
-
- // First get an authorization code
- let mut auth_params = HashMap::new();
- auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert(
- "redirect_uri".to_string(),
- "http://localhost:3000/callback".to_string(),
- );
- auth_params.insert("response_type".to_string(), "code".to_string());
-
- let auth_result = oauth_server
- .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
- .expect("Authorization failed");
- let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
- let auth_code = redirect_url
- .query_pairs()
- .find(|(key, _)| key == "code")
- .map(|(_, value)| value.to_string())
- .expect("No authorization code in redirect");
-
- // Use HTTP Basic Auth instead of form parameters
- let mut token_params = HashMap::new();
- token_params.insert("grant_type".to_string(), "authorization_code".to_string());
- token_params.insert("code".to_string(), auth_code);
- // No client_id/client_secret in form
-
- // test_client:test_secret encoded in base64
- let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ=";
-
- let result = oauth_server.handle_token(
- &token_params,
- Some(auth_header),
- Some("127.0.0.1".to_string()),
- );
- assert!(result.is_ok());
- }
-}