diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-11 15:12:59 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-11 15:12:59 -0600 |
| commit | 4435ee26b79648e92d0f172e42f9e6629e955505 (patch) | |
| tree | 0720fd07c879a58672fcfcb2e45ed1161430f039 /src/oauth | |
| parent | 39c67cfc6c74bf4b26ba455f3adda1241aea35ea (diff) | |
chore: rustfmt and include Connection: header in responses
Diffstat (limited to 'src/oauth')
| -rw-r--r-- | src/oauth/mod.rs | 4 | ||||
| -rw-r--r-- | src/oauth/pkce.rs | 57 | ||||
| -rw-r--r-- | src/oauth/server.rs | 240 | ||||
| -rw-r--r-- | src/oauth/types.rs | 2 |
4 files changed, 237 insertions, 66 deletions
diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs index 7fd0d7b..4b18bb3 100644 --- a/src/oauth/mod.rs +++ b/src/oauth/mod.rs @@ -2,6 +2,8 @@ pub mod pkce; pub mod server; pub mod types; -pub use pkce::{CodeChallengeMethod, verify_code_challenge, generate_code_verifier, generate_code_challenge}; +pub use pkce::{ + generate_code_challenge, generate_code_verifier, verify_code_challenge, CodeChallengeMethod, +}; pub use server::OAuthServer; pub use types::{AuthCode, Claims, ErrorResponse, TokenResponse}; diff --git a/src/oauth/pkce.rs b/src/oauth/pkce.rs index c943844..406d364 100644 --- a/src/oauth/pkce.rs +++ b/src/oauth/pkce.rs @@ -1,6 +1,6 @@ -use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; -use sha2::{Digest, Sha256}; use anyhow::{anyhow, Result}; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; +use sha2::{Digest, Sha256}; #[derive(Debug, Clone, PartialEq)] pub enum CodeChallengeMethod { @@ -32,13 +32,16 @@ pub fn verify_code_challenge( ) -> Result<bool> { // Validate code verifier format (RFC 7636 Section 4.1) if code_verifier.len() < 43 || code_verifier.len() > 128 { - return Err(anyhow!("Code verifier length must be between 43 and 128 characters")); + return Err(anyhow!( + "Code verifier length must be between 43 and 128 characters" + )); } // Code verifier must only contain unreserved characters - if !code_verifier.chars().all(|c| { - c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~' - }) { + if !code_verifier + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~') + { return Err(anyhow!("Code verifier contains invalid characters")); } @@ -57,7 +60,7 @@ pub fn verify_code_challenge( pub fn generate_code_verifier() -> String { use rand::Rng; let mut rng = rand::thread_rng(); - + // Generate 32 random bytes and encode them let bytes: Vec<u8> = (0..32).map(|_| rng.r#gen()).collect(); URL_SAFE_NO_PAD.encode(&bytes) @@ -80,8 +83,14 @@ mod tests { #[test] fn test_code_challenge_method_from_str() { - assert_eq!(CodeChallengeMethod::from_str("plain").unwrap(), CodeChallengeMethod::Plain); - assert_eq!(CodeChallengeMethod::from_str("S256").unwrap(), CodeChallengeMethod::S256); + assert_eq!( + CodeChallengeMethod::from_str("plain").unwrap(), + CodeChallengeMethod::Plain + ); + assert_eq!( + CodeChallengeMethod::from_str("S256").unwrap(), + CodeChallengeMethod::S256 + ); assert!(CodeChallengeMethod::from_str("invalid").is_err()); } @@ -95,7 +104,7 @@ mod tests { fn test_verify_code_challenge_plain() { let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; let challenge = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; - + assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::Plain).unwrap()); assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::Plain).unwrap()); } @@ -104,7 +113,7 @@ mod tests { fn test_verify_code_challenge_s256() { let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; let challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; - + assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::S256).unwrap()); assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::S256).unwrap()); } @@ -113,9 +122,14 @@ mod tests { fn test_verify_code_challenge_invalid_verifier() { // Too short assert!(verify_code_challenge("short", "challenge", &CodeChallengeMethod::Plain).is_err()); - + // Invalid characters - assert!(verify_code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", "challenge", &CodeChallengeMethod::Plain).is_err()); + assert!(verify_code_challenge( + "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", + "challenge", + &CodeChallengeMethod::Plain + ) + .is_err()); } #[test] @@ -123,7 +137,7 @@ mod tests { let verifier = generate_code_verifier(); assert!(verifier.len() >= 43); assert!(verifier.len() <= 128); - + // Should only contain valid characters assert!(verifier.chars().all(|c| { c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~' @@ -133,24 +147,27 @@ mod tests { #[test] fn test_generate_code_challenge() { let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; - + let plain_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::Plain); assert_eq!(plain_challenge, verifier); - + let s256_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::S256); - assert_eq!(s256_challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"); + assert_eq!( + s256_challenge, + "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + ); } #[test] fn test_round_trip() { let verifier = generate_code_verifier(); - + // Test with S256 let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::S256); assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::S256).unwrap()); - + // Test with Plain let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::Plain); assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::Plain).unwrap()); } -}
\ No newline at end of file +} diff --git a/src/oauth/server.rs b/src/oauth/server.rs index 7552f00..7fd8b9c 100644 --- a/src/oauth/server.rs +++ b/src/oauth/server.rs @@ -1,9 +1,9 @@ use crate::clients::{parse_basic_auth, ClientManager}; use crate::config::Config; -use crate::database::{Database, DbAuthCode, DbAccessToken, DbAuditLog}; +use crate::database::{Database, DbAccessToken, DbAuditLog, DbAuthCode}; use crate::keys::KeyManager; -use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge}; -use crate::oauth::types::{Claims, ErrorResponse, TokenResponse, TokenIntrospectionResponse}; +use crate::oauth::pkce::{verify_code_challenge, CodeChallengeMethod}; +use crate::oauth::types::{Claims, ErrorResponse, TokenIntrospectionResponse, TokenResponse}; use anyhow::{anyhow, Result}; use chrono::{Duration, Utc}; use jsonwebtoken::{encode, Algorithm, Header}; @@ -43,7 +43,11 @@ impl OAuthServer { } } - pub fn handle_authorize(&self, params: &HashMap<String, String>, ip_address: Option<String>) -> Result<String, String> { + pub fn handle_authorize( + &self, + params: &HashMap<String, String>, + ip_address: Option<String>, + ) -> Result<String, String> { let client_id = params .get("client_id") .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; @@ -58,7 +62,14 @@ impl OAuthServer { // Rate limiting check if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/authorize") { - self.audit_log("authorize_rate_limited", Some(client_id), None, ip_address.as_deref(), false, Some(&e.to_string())); + self.audit_log( + "authorize_rate_limited", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(&e.to_string()), + ); return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); } @@ -68,7 +79,14 @@ impl OAuthServer { match client_manager.get_client_from_db(client_id) { Ok(Some(client)) => client, Ok(None) => { - self.audit_log("authorize_invalid_client", Some(client_id), None, ip_address.as_deref(), false, None); + self.audit_log( + "authorize_invalid_client", + Some(client_id), + None, + ip_address.as_deref(), + false, + None, + ); return Err(self.error_response("invalid_client", "Invalid client_id")); } Err(_) => { @@ -81,7 +99,14 @@ impl OAuthServer { { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.is_redirect_uri_valid(client_id, redirect_uri) { - self.audit_log("authorize_invalid_redirect_uri", Some(client_id), None, ip_address.as_deref(), false, Some(redirect_uri)); + self.audit_log( + "authorize_invalid_redirect_uri", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(redirect_uri), + ); return Err(self.error_response("invalid_request", "Invalid redirect_uri")); } } @@ -91,13 +116,27 @@ impl OAuthServer { { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.is_scope_valid(client_id, &scope) { - self.audit_log("authorize_invalid_scope", Some(client_id), None, ip_address.as_deref(), false, scope.as_deref()); + self.audit_log( + "authorize_invalid_scope", + Some(client_id), + None, + ip_address.as_deref(), + false, + scope.as_deref(), + ); return Err(self.error_response("invalid_scope", "Invalid scope")); } } if response_type != "code" { - self.audit_log("authorize_unsupported_response_type", Some(client_id), None, ip_address.as_deref(), false, Some(response_type)); + self.audit_log( + "authorize_unsupported_response_type", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(response_type), + ); return Err(self.error_response( "unsupported_response_type", "Only code response type supported", @@ -106,14 +145,22 @@ impl OAuthServer { // PKCE validation (RFC 7636) let code_challenge = params.get("code_challenge"); - let code_challenge_method = params.get("code_challenge_method") + let code_challenge_method = params + .get("code_challenge_method") .map(|method| CodeChallengeMethod::from_str(method)) .transpose() .map_err(|_| self.error_response("invalid_request", "Invalid code_challenge_method"))?; // For public clients, PKCE is required if client.client_id.starts_with("public_") && code_challenge.is_none() { - self.audit_log("authorize_missing_pkce", Some(client_id), None, ip_address.as_deref(), false, None); + self.audit_log( + "authorize_missing_pkce", + Some(client_id), + None, + ip_address.as_deref(), + false, + None, + ); return Err(self.error_response("invalid_request", "PKCE required for public clients")); } @@ -131,14 +178,18 @@ impl OAuthServer { created_at: Utc::now(), is_used: false, code_challenge: code_challenge.cloned(), - code_challenge_method: code_challenge_method.as_ref().map(|m| m.as_str().to_string()), + code_challenge_method: code_challenge_method + .as_ref() + .map(|m| m.as_str().to_string()), }; // Save to database { let db = self.database.lock().unwrap(); if let Err(_) = db.create_auth_code(&db_auth_code) { - return Err(self.error_response("server_error", "Failed to create authorization code")); + return Err( + self.error_response("server_error", "Failed to create authorization code") + ); } } @@ -151,7 +202,14 @@ impl OAuthServer { redirect_url.query_pairs_mut().append_pair("state", state); } - self.audit_log("authorize_success", Some(client_id), Some("test_user"), ip_address.as_deref(), true, None); + self.audit_log( + "authorize_success", + Some(client_id), + Some("test_user"), + ip_address.as_deref(), + true, + None, + ); Ok(redirect_url.to_string()) } @@ -167,14 +225,20 @@ impl OAuthServer { .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; match grant_type.as_str() { - "authorization_code" => self.handle_authorization_code_grant(params, auth_header, ip_address), + "authorization_code" => { + self.handle_authorization_code_grant(params, auth_header, ip_address) + } "refresh_token" => self.handle_refresh_token_grant(params, auth_header, ip_address), _ => { - self.audit_log("token_unsupported_grant_type", None, None, ip_address.as_deref(), false, Some(grant_type)); - Err(self.error_response( - "unsupported_grant_type", - "Unsupported grant type", - )) + self.audit_log( + "token_unsupported_grant_type", + None, + None, + ip_address.as_deref(), + false, + Some(grant_type), + ); + Err(self.error_response("unsupported_grant_type", "Unsupported grant type")) } } } @@ -208,7 +272,14 @@ impl OAuthServer { // Rate limiting check if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/token") { - self.audit_log("token_rate_limited", Some(&client_id), None, ip_address.as_deref(), false, Some(&e.to_string())); + self.audit_log( + "token_rate_limited", + Some(&client_id), + None, + ip_address.as_deref(), + false, + Some(&e.to_string()), + ); return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); } @@ -216,7 +287,14 @@ impl OAuthServer { { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.authenticate_client(&client_id, &client_secret) { - self.audit_log("token_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None); + self.audit_log( + "token_invalid_client", + Some(&client_id), + None, + ip_address.as_deref(), + false, + None, + ); return Err(self.error_response("invalid_client", "Client authentication failed")); } } @@ -227,8 +305,16 @@ impl OAuthServer { match db.get_auth_code(code) { Ok(Some(auth_code)) => auth_code, Ok(None) => { - self.audit_log("token_invalid_code", Some(&client_id), None, ip_address.as_deref(), false, Some(code)); - return Err(self.error_response("invalid_grant", "Invalid or expired authorization code")); + self.audit_log( + "token_invalid_code", + Some(&client_id), + None, + ip_address.as_deref(), + false, + Some(code), + ); + return Err(self + .error_response("invalid_grant", "Invalid or expired authorization code")); } Err(_) => { return Err(self.error_response("server_error", "Internal server error")); @@ -238,17 +324,38 @@ impl OAuthServer { // Validate code hasn't been used and hasn't expired if auth_code.is_used { - self.audit_log("token_code_reuse", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code)); + self.audit_log( + "token_code_reuse", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + Some(code), + ); return Err(self.error_response("invalid_grant", "Authorization code already used")); } if Utc::now() > auth_code.expires_at { - self.audit_log("token_code_expired", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code)); + self.audit_log( + "token_code_expired", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + Some(code), + ); return Err(self.error_response("invalid_grant", "Authorization code expired")); } if auth_code.client_id != client_id { - self.audit_log("token_client_mismatch", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None); + self.audit_log( + "token_client_mismatch", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + None, + ); return Err(self.error_response("invalid_grant", "Client ID mismatch")); } @@ -258,13 +365,22 @@ impl OAuthServer { self.error_response("invalid_request", "Missing code_verifier for PKCE") })?; - let challenge_method = auth_code.code_challenge_method + let challenge_method = auth_code + .code_challenge_method .as_ref() .and_then(|method| CodeChallengeMethod::from_str(method).ok()) .unwrap_or(CodeChallengeMethod::Plain); - if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method) { - self.audit_log("token_pkce_verification_failed", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None); + if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method) + { + self.audit_log( + "token_pkce_verification_failed", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + None, + ); return Err(self.error_response("invalid_grant", "PKCE verification failed")); } } @@ -279,8 +395,14 @@ impl OAuthServer { // Generate tokens let token_id = Uuid::new_v4().to_string(); - let access_token = self.generate_access_token(&auth_code.user_id, &client_id, &auth_code.scope, &token_id)?; - let refresh_token = self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?; + let access_token = self.generate_access_token( + &auth_code.user_id, + &client_id, + &auth_code.scope, + &token_id, + )?; + let refresh_token = + self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?; // Store token in database for revocation/introspection let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); @@ -311,7 +433,14 @@ impl OAuthServer { scope: auth_code.scope, }; - self.audit_log("token_success", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), true, None); + self.audit_log( + "token_success", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + true, + None, + ); serde_json::to_string(&token_response) .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) @@ -346,7 +475,14 @@ impl OAuthServer { { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.authenticate_client(&client_id, &client_secret) { - self.audit_log("refresh_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None); + self.audit_log( + "refresh_invalid_client", + Some(&client_id), + None, + ip_address.as_deref(), + false, + None, + ); return Err(self.error_response("invalid_client", "Client authentication failed")); } } @@ -354,7 +490,8 @@ impl OAuthServer { // Validate refresh token (implementation would verify token and get user info) // For now, return a simple refresh token response let new_token_id = Uuid::new_v4().to_string(); - let access_token = self.generate_access_token("test_user", &client_id, &None, &new_token_id)?; + let access_token = + self.generate_access_token("test_user", &client_id, &None, &new_token_id)?; let new_refresh_token = self.generate_refresh_token(&client_id, "test_user", &None)?; let token_response = TokenResponse { @@ -365,7 +502,14 @@ impl OAuthServer { scope: None, }; - self.audit_log("refresh_success", Some(&client_id), Some("test_user"), ip_address.as_deref(), true, None); + self.audit_log( + "refresh_success", + Some(&client_id), + Some("test_user"), + ip_address.as_deref(), + true, + None, + ); serde_json::to_string(&token_response) .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) @@ -543,15 +687,23 @@ impl OAuthServer { fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> { let db = self.database.lock().unwrap(); let count = db.increment_rate_limit(identifier, endpoint, 1)?; - + if count > self.config.rate_limit_requests_per_minute as i32 { return Err(anyhow!("Rate limit exceeded")); } - + Ok(()) } - fn audit_log(&self, event_type: &str, client_id: Option<&str>, user_id: Option<&str>, ip_address: Option<&str>, success: bool, details: Option<&str>) { + fn audit_log( + &self, + event_type: &str, + client_id: Option<&str>, + user_id: Option<&str>, + ip_address: Option<&str>, + success: bool, + details: Option<&str>, + ) { if !self.config.enable_audit_logging { return; } @@ -584,16 +736,16 @@ impl OAuthServer { // Cleanup expired data pub fn cleanup_expired_data(&self) -> Result<()> { let db = self.database.lock().unwrap(); - + // Cleanup expired authorization codes let _ = db.cleanup_expired_codes(); - + // Cleanup expired tokens let _ = db.cleanup_expired_tokens(); - + // Cleanup old audit logs (keep for 30 days) let _ = db.cleanup_old_audit_logs(30); - + Ok(()) } -}
\ No newline at end of file +} diff --git a/src/oauth/types.rs b/src/oauth/types.rs index 0f9be5c..4f2c363 100644 --- a/src/oauth/types.rs +++ b/src/oauth/types.rs @@ -1,5 +1,5 @@ -use serde::{Deserialize, Serialize}; use crate::oauth::pkce::CodeChallengeMethod; +use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] pub struct Claims { |
