summaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-11 15:15:41 -0600
committermo khan <mo@mokhan.ca>2025-06-11 15:15:41 -0600
commitaea6bd6ec7d7e70a67723edf6327df4a9cc65d89 (patch)
tree80fcb6cbda7baa5ed15cf044d7583acb2438c4d2 /src/main.rs
parent4435ee26b79648e92d0f172e42f9e6629e955505 (diff)
chore: run rustfmt again
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs49
1 files changed, 27 insertions, 22 deletions
diff --git a/src/main.rs b/src/main.rs
index f5951e0..0612bfa 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,31 +1,36 @@
-use sts::http::Server;
-use sts::Config;
use std::thread;
use std::time::Duration;
+use sts::Config;
+use sts::http::Server;
fn main() {
let config = Config::from_env();
let server = Server::new(config.clone()).expect("Failed to create server");
-
+
// Start cleanup task in background
let cleanup_config = config.clone();
thread::spawn(move || {
loop {
- thread::sleep(Duration::from_secs(cleanup_config.cleanup_interval_hours as u64 * 3600));
+ thread::sleep(Duration::from_secs(
+ cleanup_config.cleanup_interval_hours as u64 * 3600,
+ ));
// Note: In the current implementation, we don't have direct access to the OAuth server
// from here to call cleanup_expired_data(). In a production implementation,
// you'd want to structure this differently or use a background job queue.
}
});
-
+
println!("Starting OAuth2 STS server...");
println!("Configuration:");
println!(" Bind Address: {}", config.bind_addr);
println!(" Issuer URL: {}", config.issuer_url);
println!(" Database: {}", config.database_path);
- println!(" Rate Limit: {} requests/minute", config.rate_limit_requests_per_minute);
+ println!(
+ " Rate Limit: {} requests/minute",
+ config.rate_limit_requests_per_minute
+ );
println!(" Audit Logging: {}", config.enable_audit_logging);
-
+
server.start();
}
@@ -102,15 +107,15 @@ mod disabled_tests {
fn test_jwks_endpoint() {
let config = sts::Config::from_env();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
-
+
let jwks_json = oauth_server.get_jwks();
assert!(!jwks_json.is_empty());
-
+
// Parse the JSON to verify structure
let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON");
assert!(jwks["keys"].is_array());
assert!(jwks["keys"].as_array().unwrap().len() > 0);
-
+
let key = &jwks["keys"][0];
assert_eq!(key["kty"], "RSA");
assert_eq!(key["use"], "sig");
@@ -133,7 +138,7 @@ mod disabled_tests {
auth_params.insert("state".to_string(), "test_state".to_string());
let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
-
+
// Extract the authorization code from redirect URL
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
@@ -150,17 +155,17 @@ mod disabled_tests {
token_params.insert("client_secret".to_string(), "test_secret".to_string());
let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed");
-
+
// Parse token response
let token_response: serde_json::Value = serde_json::from_str(&token_result)
.expect("Invalid token response JSON");
-
+
assert_eq!(token_response["token_type"], "Bearer");
assert_eq!(token_response["expires_in"], 3600);
assert!(token_response["access_token"].is_string());
-
+
let access_token = token_response["access_token"].as_str().unwrap();
-
+
// Step 3: Verify the JWT token has RSA signature and key ID
let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
@@ -168,7 +173,7 @@ mod disabled_tests {
assert!(!header.kid.as_ref().unwrap().is_empty());
}
- #[test]
+ #[test]
fn test_token_validation_with_jwks() {
let config = sts::Config::from_env();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
@@ -201,11 +206,11 @@ mod disabled_tests {
// Get the JWKS
let jwks_json = oauth_server.get_jwks();
let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON");
-
+
// Decode the token header to get the key ID
let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
let kid = header.kid.as_ref().expect("No key ID in token");
-
+
// Find the matching key in JWKS
let matching_key = jwks["keys"]
.as_array()
@@ -213,7 +218,7 @@ mod disabled_tests {
.iter()
.find(|key| key["kid"] == *kid)
.expect("Key ID not found in JWKS");
-
+
assert_eq!(matching_key["kty"], "RSA");
assert_eq!(matching_key["alg"], "RS256");
}
@@ -255,17 +260,17 @@ mod disabled_tests {
&jsonwebtoken::DecodingKey::from_secret(b"dummy"), // We're not validating, just parsing
&jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256)
);
-
+
// Since we can't validate with a dummy key, we'll just verify the structure
// by decoding the payload manually
let parts: Vec<&str> = access_token.split('.').collect();
assert_eq!(parts.len(), 3); // header.payload.signature
-
+
let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(parts[1])
.expect("Failed to decode payload");
let claims: serde_json::Value = serde_json::from_slice(&payload).expect("Invalid claims JSON");
-
+
assert!(claims["sub"].is_string());
assert!(claims["iss"].is_string());
assert!(claims["aud"].is_string());