diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/http/mod.rs | 16 | ||||
| -rw-r--r-- | src/keys.rs | 29 | ||||
| -rw-r--r-- | src/oauth/mod.rs | 2 | ||||
| -rw-r--r-- | src/oauth/server.rs | 18 | ||||
| -rw-r--r-- | src/oauth/types.rs | 2 |
5 files changed, 39 insertions, 28 deletions
diff --git a/src/http/mod.rs b/src/http/mod.rs index 7b1b983..11887ae 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -73,9 +73,7 @@ impl Server { match (method, path) { ("GET", "/") => self.serve_static_file(&mut stream, "./public/index.html"), - ("GET", "/.well-known/oauth-authorization-server") => { - self.handle_metadata(&mut stream) - } + ("GET", "/.well-known/oauth-authorization-server") => self.handle_metadata(&mut stream), ("GET", "/jwks") => self.handle_jwks(&mut stream), ("GET", "/authorize") => self.handle_authorize(&mut stream, &query_params), ("POST", "/token") => self.handle_token(&mut stream, &request), @@ -166,11 +164,14 @@ impl Server { fn handle_token(&self, stream: &mut TcpStream, request: &str) { let body = self.extract_body(request); let form_params = self.parse_form_data(&body); - + // Extract Authorization header from request let auth_header = self.extract_auth_header(request); - match self.oauth_server.handle_token(&form_params, auth_header.as_deref()) { + match self + .oauth_server + .handle_token(&form_params, auth_header.as_deref()) + { Ok(token_response) => { self.send_json_response(stream, 200, "OK", &token_response); } @@ -206,11 +207,12 @@ impl Server { fn extract_auth_header(&self, request: &str) -> Option<String> { let lines: Vec<&str> = request.lines().collect(); - for line in lines.iter().skip(1) { // Skip the request line + for line in lines.iter().skip(1) { + // Skip the request line if line.to_lowercase().starts_with("authorization:") { return Some(line[14..].trim().to_string()); // Skip "Authorization: " } } None } -}
\ No newline at end of file +} diff --git a/src/keys.rs b/src/keys.rs index 10c9697..88060f3 100644 --- a/src/keys.rs +++ b/src/keys.rs @@ -169,8 +169,10 @@ mod tests { fn test_key_generation() { let mut manager = KeyManager::new().expect("Failed to create key manager"); let initial_key_count = manager.keys.len(); - - let new_kid = manager.generate_new_key().expect("Failed to generate new key"); + + let new_kid = manager + .generate_new_key() + .expect("Failed to generate new key"); assert_eq!(manager.keys.len(), initial_key_count + 1); assert_eq!(manager.current_key_id, Some(new_kid.clone())); assert!(manager.get_key(&new_kid).is_some()); @@ -180,7 +182,7 @@ mod tests { fn test_jwks_generation() { let manager = KeyManager::new().expect("Failed to create key manager"); let jwks = manager.get_jwks().expect("Failed to get JWKS"); - + assert_eq!(jwks.keys.len(), 1); let key = &jwks.keys[0]; assert_eq!(key.kty, "RSA"); @@ -195,9 +197,9 @@ mod tests { fn test_key_rotation() { let mut manager = KeyManager::new().expect("Failed to create key manager"); let original_kid = manager.current_key_id.clone().unwrap(); - + manager.rotate_keys().expect("Failed to rotate keys"); - + let new_kid = manager.current_key_id.clone().unwrap(); assert_ne!(original_kid, new_kid); assert_eq!(manager.keys.len(), 2); // Should have both old and new keys @@ -215,22 +217,23 @@ mod tests { #[test] fn test_should_rotate_old_key() { let mut manager = KeyManager::new().expect("Failed to create key manager"); - + // Manually modify the current key's creation time to be old if let Some(current_kid) = manager.current_key_id.clone() { if let Some(key_pair) = manager.keys.get_mut(¤t_kid) { let old_time = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() - .as_secs() - 86401; // 1 day + 1 second ago - + .as_secs() + - 86401; // 1 day + 1 second ago + // We need to recreate the key pair with the old timestamp let mut old_key_pair = key_pair.clone(); old_key_pair.created_at = old_time; manager.keys.insert(current_kid, old_key_pair); } } - + // Should need rotation since key is older than rotation interval assert!(manager.should_rotate()); } @@ -239,11 +242,11 @@ mod tests { fn test_cleanup_old_keys() { let mut manager = KeyManager::new().expect("Failed to create key manager"); let original_kid = manager.current_key_id.clone().unwrap(); - + // Generate a new key (so we have 2 keys) manager.rotate_keys().expect("Failed to rotate keys"); assert_eq!(manager.keys.len(), 2); - + // Cleanup with max_age 0 should remove old keys but keep current manager.cleanup_old_keys(0); assert_eq!(manager.keys.len(), 1); @@ -256,10 +259,10 @@ mod tests { let mut manager = KeyManager::new().expect("Failed to create key manager"); manager.rotate_keys().expect("Failed to rotate keys"); manager.rotate_keys().expect("Failed to rotate keys"); - + let jwks = manager.get_jwks().expect("Failed to get JWKS"); assert_eq!(jwks.keys.len(), 3); // Should have 3 keys - + // All keys should have unique key IDs let mut kids: Vec<String> = jwks.keys.iter().map(|k| k.kid.clone()).collect(); kids.sort(); diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs index d6717c2..3a0d861 100644 --- a/src/oauth/mod.rs +++ b/src/oauth/mod.rs @@ -2,4 +2,4 @@ pub mod server; pub mod types; pub use server::OAuthServer; -pub use types::{AuthCode, Claims, ErrorResponse, TokenResponse};
\ No newline at end of file +pub use types::{AuthCode, Claims, ErrorResponse, TokenResponse}; diff --git a/src/oauth/server.rs b/src/oauth/server.rs index 3ee0bb3..243fdba 100644 --- a/src/oauth/server.rs +++ b/src/oauth/server.rs @@ -1,8 +1,8 @@ -use crate::clients::{ClientManager, parse_basic_auth}; +use crate::clients::{parse_basic_auth, ClientManager}; use crate::config::Config; use crate::keys::KeyManager; use crate::oauth::types::{AuthCode, Claims, ErrorResponse, TokenResponse}; -use jsonwebtoken::{Algorithm, Header, encode}; +use jsonwebtoken::{encode, Algorithm, Header}; use std::collections::HashMap; use std::time::{SystemTime, UNIX_EPOCH}; use url::Url; @@ -51,7 +51,8 @@ impl OAuthServer { // Validate client exists let client_manager = self.client_manager.lock().unwrap(); - let _client = client_manager.get_client(client_id) + let _client = client_manager + .get_client(client_id) .ok_or_else(|| self.error_response("invalid_client", "Invalid client_id"))?; // Validate redirect URI is registered for this client @@ -104,7 +105,11 @@ impl OAuthServer { Ok(redirect_url.to_string()) } - pub fn handle_token(&self, params: &HashMap<String, String>, auth_header: Option<&str>) -> Result<String, String> { + pub fn handle_token( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ) -> Result<String, String> { let grant_type = params .get("grant_type") .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; @@ -124,8 +129,9 @@ impl OAuthServer { // Clients can authenticate via HTTP Basic Auth or form parameters let (client_id, client_secret) = if let Some(auth_header) = auth_header { // HTTP Basic Authentication (preferred method) - parse_basic_auth(auth_header) - .ok_or_else(|| self.error_response("invalid_client", "Invalid Authorization header"))? + parse_basic_auth(auth_header).ok_or_else(|| { + self.error_response("invalid_client", "Invalid Authorization header") + })? } else { // Form-based authentication (fallback) let client_id = params diff --git a/src/oauth/types.rs b/src/oauth/types.rs index 685af02..6c62edf 100644 --- a/src/oauth/types.rs +++ b/src/oauth/types.rs @@ -36,4 +36,4 @@ pub struct AuthCode { pub scope: Option<String>, pub expires_at: u64, pub user_id: String, -}
\ No newline at end of file +} |
