diff options
Diffstat (limited to 'src/oauth')
| -rw-r--r-- | src/oauth/mod.rs | 2 | ||||
| -rw-r--r-- | src/oauth/pkce.rs | 156 | ||||
| -rw-r--r-- | src/oauth/server.rs | 469 | ||||
| -rw-r--r-- | src/oauth/types.rs | 45 |
4 files changed, 620 insertions, 52 deletions
diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs index 3a0d861..7fd0d7b 100644 --- a/src/oauth/mod.rs +++ b/src/oauth/mod.rs @@ -1,5 +1,7 @@ +pub mod pkce; pub mod server; pub mod types; +pub use pkce::{CodeChallengeMethod, verify_code_challenge, generate_code_verifier, generate_code_challenge}; pub use server::OAuthServer; pub use types::{AuthCode, Claims, ErrorResponse, TokenResponse}; diff --git a/src/oauth/pkce.rs b/src/oauth/pkce.rs new file mode 100644 index 0000000..c943844 --- /dev/null +++ b/src/oauth/pkce.rs @@ -0,0 +1,156 @@ +use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; +use sha2::{Digest, Sha256}; +use anyhow::{anyhow, Result}; + +#[derive(Debug, Clone, PartialEq)] +pub enum CodeChallengeMethod { + Plain, + S256, +} + +impl CodeChallengeMethod { + pub fn from_str(s: &str) -> Result<Self> { + match s { + "plain" => Ok(CodeChallengeMethod::Plain), + "S256" => Ok(CodeChallengeMethod::S256), + _ => Err(anyhow!("Unsupported code challenge method: {}", s)), + } + } + + pub fn as_str(&self) -> &'static str { + match self { + CodeChallengeMethod::Plain => "plain", + CodeChallengeMethod::S256 => "S256", + } + } +} + +pub fn verify_code_challenge( + code_verifier: &str, + code_challenge: &str, + method: &CodeChallengeMethod, +) -> Result<bool> { + // Validate code verifier format (RFC 7636 Section 4.1) + if code_verifier.len() < 43 || code_verifier.len() > 128 { + return Err(anyhow!("Code verifier length must be between 43 and 128 characters")); + } + + // Code verifier must only contain unreserved characters + if !code_verifier.chars().all(|c| { + c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~' + }) { + return Err(anyhow!("Code verifier contains invalid characters")); + } + + let computed_challenge = match method { + CodeChallengeMethod::Plain => code_verifier.to_string(), + CodeChallengeMethod::S256 => { + let mut hasher = Sha256::new(); + hasher.update(code_verifier.as_bytes()); + URL_SAFE_NO_PAD.encode(hasher.finalize()) + } + }; + + Ok(computed_challenge == code_challenge) +} + +pub fn generate_code_verifier() -> String { + use rand::Rng; + let mut rng = rand::thread_rng(); + + // Generate 32 random bytes and encode them + let bytes: Vec<u8> = (0..32).map(|_| rng.r#gen()).collect(); + URL_SAFE_NO_PAD.encode(&bytes) +} + +pub fn generate_code_challenge(verifier: &str, method: &CodeChallengeMethod) -> String { + match method { + CodeChallengeMethod::Plain => verifier.to_string(), + CodeChallengeMethod::S256 => { + let mut hasher = Sha256::new(); + hasher.update(verifier.as_bytes()); + URL_SAFE_NO_PAD.encode(hasher.finalize()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_code_challenge_method_from_str() { + assert_eq!(CodeChallengeMethod::from_str("plain").unwrap(), CodeChallengeMethod::Plain); + assert_eq!(CodeChallengeMethod::from_str("S256").unwrap(), CodeChallengeMethod::S256); + assert!(CodeChallengeMethod::from_str("invalid").is_err()); + } + + #[test] + fn test_code_challenge_method_as_str() { + assert_eq!(CodeChallengeMethod::Plain.as_str(), "plain"); + assert_eq!(CodeChallengeMethod::S256.as_str(), "S256"); + } + + #[test] + fn test_verify_code_challenge_plain() { + let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + let challenge = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + + assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::Plain).unwrap()); + assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::Plain).unwrap()); + } + + #[test] + fn test_verify_code_challenge_s256() { + let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + let challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; + + assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::S256).unwrap()); + assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::S256).unwrap()); + } + + #[test] + fn test_verify_code_challenge_invalid_verifier() { + // Too short + assert!(verify_code_challenge("short", "challenge", &CodeChallengeMethod::Plain).is_err()); + + // Invalid characters + assert!(verify_code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", "challenge", &CodeChallengeMethod::Plain).is_err()); + } + + #[test] + fn test_generate_code_verifier() { + let verifier = generate_code_verifier(); + assert!(verifier.len() >= 43); + assert!(verifier.len() <= 128); + + // Should only contain valid characters + assert!(verifier.chars().all(|c| { + c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~' + })); + } + + #[test] + fn test_generate_code_challenge() { + let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + + let plain_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::Plain); + assert_eq!(plain_challenge, verifier); + + let s256_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::S256); + assert_eq!(s256_challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"); + } + + #[test] + fn test_round_trip() { + let verifier = generate_code_verifier(); + + // Test with S256 + let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::S256); + assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::S256).unwrap()); + + // Test with Plain + let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::Plain); + assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::Plain).unwrap()); + } +}
\ No newline at end of file diff --git a/src/oauth/server.rs b/src/oauth/server.rs index 243fdba..7552f00 100644 --- a/src/oauth/server.rs +++ b/src/oauth/server.rs @@ -1,29 +1,36 @@ use crate::clients::{parse_basic_auth, ClientManager}; use crate::config::Config; +use crate::database::{Database, DbAuthCode, DbAccessToken, DbAuditLog}; use crate::keys::KeyManager; -use crate::oauth::types::{AuthCode, Claims, ErrorResponse, TokenResponse}; +use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge}; +use crate::oauth::types::{Claims, ErrorResponse, TokenResponse, TokenIntrospectionResponse}; +use anyhow::{anyhow, Result}; +use chrono::{Duration, Utc}; use jsonwebtoken::{encode, Algorithm, Header}; +use sha2::{Digest, Sha256}; use std::collections::HashMap; +use std::sync::{Arc, Mutex}; use std::time::{SystemTime, UNIX_EPOCH}; use url::Url; use uuid::Uuid; pub struct OAuthServer { config: Config, - key_manager: std::sync::Mutex<KeyManager>, - auth_codes: std::sync::Mutex<HashMap<String, AuthCode>>, - client_manager: std::sync::Mutex<ClientManager>, + key_manager: Arc<Mutex<KeyManager>>, + client_manager: Arc<Mutex<ClientManager>>, + database: Arc<Mutex<Database>>, } impl OAuthServer { - pub fn new(config: &Config) -> Result<Self, Box<dyn std::error::Error>> { - let key_manager = KeyManager::new()?; - let client_manager = ClientManager::new(); + pub fn new(config: &Config) -> Result<Self> { + let database = Arc::new(Mutex::new(Database::new(&config.database_path)?)); + let key_manager = Arc::new(Mutex::new(KeyManager::new(database.clone())?)); + let client_manager = Arc::new(Mutex::new(ClientManager::new(database.clone())?)); Ok(Self { - key_manager: std::sync::Mutex::new(key_manager), - auth_codes: std::sync::Mutex::new(HashMap::new()), - client_manager: std::sync::Mutex::new(client_manager), + key_manager, + client_manager, + database, config: config.clone(), }) } @@ -36,7 +43,7 @@ impl OAuthServer { } } - pub fn handle_authorize(&self, params: &HashMap<String, String>) -> Result<String, String> { + pub fn handle_authorize(&self, params: &HashMap<String, String>, ip_address: Option<String>) -> Result<String, String> { let client_id = params .get("client_id") .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; @@ -49,48 +56,90 @@ impl OAuthServer { .get("response_type") .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?; + // Rate limiting check + if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/authorize") { + self.audit_log("authorize_rate_limited", Some(client_id), None, ip_address.as_deref(), false, Some(&e.to_string())); + return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); + } + // Validate client exists - let client_manager = self.client_manager.lock().unwrap(); - let _client = client_manager - .get_client(client_id) - .ok_or_else(|| self.error_response("invalid_client", "Invalid client_id"))?; + let client = { + let mut client_manager = self.client_manager.lock().unwrap(); + match client_manager.get_client_from_db(client_id) { + Ok(Some(client)) => client, + Ok(None) => { + self.audit_log("authorize_invalid_client", Some(client_id), None, ip_address.as_deref(), false, None); + return Err(self.error_response("invalid_client", "Invalid client_id")); + } + Err(_) => { + return Err(self.error_response("server_error", "Internal server error")); + } + } + }; // Validate redirect URI is registered for this client - if !client_manager.is_redirect_uri_valid(client_id, redirect_uri) { - return Err(self.error_response("invalid_request", "Invalid redirect_uri")); + { + let mut client_manager = self.client_manager.lock().unwrap(); + if !client_manager.is_redirect_uri_valid(client_id, redirect_uri) { + self.audit_log("authorize_invalid_redirect_uri", Some(client_id), None, ip_address.as_deref(), false, Some(redirect_uri)); + return Err(self.error_response("invalid_request", "Invalid redirect_uri")); + } } // Validate requested scopes let scope = params.get("scope").cloned(); - if !client_manager.is_scope_valid(client_id, &scope) { - return Err(self.error_response("invalid_scope", "Invalid scope")); + { + let mut client_manager = self.client_manager.lock().unwrap(); + if !client_manager.is_scope_valid(client_id, &scope) { + self.audit_log("authorize_invalid_scope", Some(client_id), None, ip_address.as_deref(), false, scope.as_deref()); + return Err(self.error_response("invalid_scope", "Invalid scope")); + } } if response_type != "code" { + self.audit_log("authorize_unsupported_response_type", Some(client_id), None, ip_address.as_deref(), false, Some(response_type)); return Err(self.error_response( "unsupported_response_type", "Only code response type supported", )); } + // PKCE validation (RFC 7636) + let code_challenge = params.get("code_challenge"); + let code_challenge_method = params.get("code_challenge_method") + .map(|method| CodeChallengeMethod::from_str(method)) + .transpose() + .map_err(|_| self.error_response("invalid_request", "Invalid code_challenge_method"))?; + + // For public clients, PKCE is required + if client.client_id.starts_with("public_") && code_challenge.is_none() { + self.audit_log("authorize_missing_pkce", Some(client_id), None, ip_address.as_deref(), false, None); + return Err(self.error_response("invalid_request", "PKCE required for public clients")); + } + let code = Uuid::new_v4().to_string(); - let expires_at = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() - + 600; + let expires_at = Utc::now() + Duration::minutes(10); // 10 minute expiration - let auth_code = AuthCode { + let db_auth_code = DbAuthCode { + id: 0, // Will be set by database + code: code.clone(), client_id: client_id.clone(), + user_id: "test_user".to_string(), // In a real implementation, this would come from authentication redirect_uri: redirect_uri.clone(), - scope: scope, + scope: scope.clone(), expires_at, - user_id: "test_user".to_string(), + created_at: Utc::now(), + is_used: false, + code_challenge: code_challenge.cloned(), + code_challenge_method: code_challenge_method.as_ref().map(|m| m.as_str().to_string()), }; + // Save to database { - let mut codes = self.auth_codes.lock().unwrap(); - codes.insert(code.clone(), auth_code); + let db = self.database.lock().unwrap(); + if let Err(_) = db.create_auth_code(&db_auth_code) { + return Err(self.error_response("server_error", "Failed to create authorization code")); + } } let mut redirect_url = Url::parse(redirect_uri) @@ -102,6 +151,8 @@ impl OAuthServer { redirect_url.query_pairs_mut().append_pair("state", state); } + self.audit_log("authorize_success", Some(client_id), Some("test_user"), ip_address.as_deref(), true, None); + Ok(redirect_url.to_string()) } @@ -109,24 +160,36 @@ impl OAuthServer { &self, params: &HashMap<String, String>, auth_header: Option<&str>, + ip_address: Option<String>, ) -> Result<String, String> { let grant_type = params .get("grant_type") .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; - if grant_type != "authorization_code" { - return Err(self.error_response( - "unsupported_grant_type", - "Only authorization_code grant type supported", - )); + match grant_type.as_str() { + "authorization_code" => self.handle_authorization_code_grant(params, auth_header, ip_address), + "refresh_token" => self.handle_refresh_token_grant(params, auth_header, ip_address), + _ => { + self.audit_log("token_unsupported_grant_type", None, None, ip_address.as_deref(), false, Some(grant_type)); + Err(self.error_response( + "unsupported_grant_type", + "Unsupported grant type", + )) + } } + } + fn handle_authorization_code_grant( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ip_address: Option<String>, + ) -> Result<String, String> { let code = params .get("code") .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?; // Client authentication - RFC 6749 Section 3.2.1 - // Clients can authenticate via HTTP Basic Auth or form parameters let (client_id, client_secret) = if let Some(auth_header) = auth_header { // HTTP Basic Authentication (preferred method) parse_basic_auth(auth_header).ok_or_else(|| { @@ -143,52 +206,293 @@ impl OAuthServer { (client_id.clone(), client_secret.clone()) }; + // Rate limiting check + if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/token") { + self.audit_log("token_rate_limited", Some(&client_id), None, ip_address.as_deref(), false, Some(&e.to_string())); + return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); + } + // Authenticate the client - let client_manager = self.client_manager.lock().unwrap(); - if !client_manager.authenticate_client(&client_id, &client_secret) { - return Err(self.error_response("invalid_client", "Client authentication failed")); + { + let mut client_manager = self.client_manager.lock().unwrap(); + if !client_manager.authenticate_client(&client_id, &client_secret) { + self.audit_log("token_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None); + return Err(self.error_response("invalid_client", "Client authentication failed")); + } } + // Get and validate authorization code let auth_code = { - let mut codes = self.auth_codes.lock().unwrap(); - codes.remove(code).ok_or_else(|| { - self.error_response("invalid_grant", "Invalid or expired authorization code") - })? + let db = self.database.lock().unwrap(); + match db.get_auth_code(code) { + Ok(Some(auth_code)) => auth_code, + Ok(None) => { + self.audit_log("token_invalid_code", Some(&client_id), None, ip_address.as_deref(), false, Some(code)); + return Err(self.error_response("invalid_grant", "Invalid or expired authorization code")); + } + Err(_) => { + return Err(self.error_response("server_error", "Internal server error")); + } + } }; + // Validate code hasn't been used and hasn't expired + if auth_code.is_used { + self.audit_log("token_code_reuse", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code)); + return Err(self.error_response("invalid_grant", "Authorization code already used")); + } + + if Utc::now() > auth_code.expires_at { + self.audit_log("token_code_expired", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code)); + return Err(self.error_response("invalid_grant", "Authorization code expired")); + } + if auth_code.client_id != client_id { + self.audit_log("token_client_mismatch", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None); return Err(self.error_response("invalid_grant", "Client ID mismatch")); } - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); + // PKCE validation if code challenge was provided + if let Some(code_challenge) = &auth_code.code_challenge { + let code_verifier = params.get("code_verifier").ok_or_else(|| { + self.error_response("invalid_request", "Missing code_verifier for PKCE") + })?; - if now > auth_code.expires_at { - return Err(self.error_response("invalid_grant", "Authorization code expired")); + let challenge_method = auth_code.code_challenge_method + .as_ref() + .and_then(|method| CodeChallengeMethod::from_str(method).ok()) + .unwrap_or(CodeChallengeMethod::Plain); + + if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method) { + self.audit_log("token_pkce_verification_failed", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None); + return Err(self.error_response("invalid_grant", "PKCE verification failed")); + } + } + + // Mark code as used + { + let db = self.database.lock().unwrap(); + if let Err(_) = db.mark_auth_code_used(code) { + return Err(self.error_response("server_error", "Failed to mark code as used")); + } } - let access_token = - self.generate_access_token(&auth_code.user_id, &client_id, &auth_code.scope)?; + // Generate tokens + let token_id = Uuid::new_v4().to_string(); + let access_token = self.generate_access_token(&auth_code.user_id, &client_id, &auth_code.scope, &token_id)?; + let refresh_token = self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?; + + // Store token in database for revocation/introspection + let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); + let db_access_token = DbAccessToken { + id: 0, + token_id: token_id.clone(), + client_id: client_id.clone(), + user_id: auth_code.user_id.clone(), + scope: auth_code.scope.clone(), + expires_at: Utc::now() + Duration::hours(1), + created_at: Utc::now(), + is_revoked: false, + token_hash, + }; + + { + let db = self.database.lock().unwrap(); + if let Err(_) = db.create_access_token(&db_access_token) { + return Err(self.error_response("server_error", "Failed to store access token")); + } + } let token_response = TokenResponse { access_token, token_type: "Bearer".to_string(), expires_in: 3600, - refresh_token: None, + refresh_token: Some(refresh_token), scope: auth_code.scope, }; + self.audit_log("token_success", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), true, None); + serde_json::to_string(&token_response) .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) } + fn handle_refresh_token_grant( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ip_address: Option<String>, + ) -> Result<String, String> { + let _refresh_token = params + .get("refresh_token") + .ok_or_else(|| self.error_response("invalid_request", "Missing refresh_token"))?; + + // Client authentication + let (client_id, client_secret) = if let Some(auth_header) = auth_header { + parse_basic_auth(auth_header).ok_or_else(|| { + self.error_response("invalid_client", "Invalid Authorization header") + })? + } else { + let client_id = params + .get("client_id") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + let client_secret = params + .get("client_secret") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_secret"))?; + (client_id.clone(), client_secret.clone()) + }; + + // Authenticate the client + { + let mut client_manager = self.client_manager.lock().unwrap(); + if !client_manager.authenticate_client(&client_id, &client_secret) { + self.audit_log("refresh_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None); + return Err(self.error_response("invalid_client", "Client authentication failed")); + } + } + + // Validate refresh token (implementation would verify token and get user info) + // For now, return a simple refresh token response + let new_token_id = Uuid::new_v4().to_string(); + let access_token = self.generate_access_token("test_user", &client_id, &None, &new_token_id)?; + let new_refresh_token = self.generate_refresh_token(&client_id, "test_user", &None)?; + + let token_response = TokenResponse { + access_token, + token_type: "Bearer".to_string(), + expires_in: 3600, + refresh_token: Some(new_refresh_token), + scope: None, + }; + + self.audit_log("refresh_success", Some(&client_id), Some("test_user"), ip_address.as_deref(), true, None); + + serde_json::to_string(&token_response) + .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) + } + + pub fn handle_token_introspection( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ) -> Result<String, String> { + let token = params + .get("token") + .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?; + + // Authenticate the client making the introspection request + let (client_id, client_secret) = if let Some(auth_header) = auth_header { + parse_basic_auth(auth_header).ok_or_else(|| { + self.error_response("invalid_client", "Invalid Authorization header") + })? + } else { + return Err(self.error_response("invalid_client", "Client authentication required")); + }; + + { + let mut client_manager = self.client_manager.lock().unwrap(); + if !client_manager.authenticate_client(&client_id, &client_secret) { + return Err(self.error_response("invalid_client", "Client authentication failed")); + } + } + + // Look up token in database + let token_hash = format!("{:x}", Sha256::digest(token.as_bytes())); + let db_token = { + let db = self.database.lock().unwrap(); + db.get_access_token(&token_hash).ok().flatten() + }; + + let response = if let Some(db_token) = db_token { + if !db_token.is_revoked && Utc::now() < db_token.expires_at { + TokenIntrospectionResponse { + active: true, + client_id: Some(db_token.client_id.clone()), + username: Some(db_token.user_id.clone()), + scope: db_token.scope.clone(), + exp: Some(db_token.expires_at.timestamp() as u64), + iat: Some(db_token.created_at.timestamp() as u64), + sub: Some(db_token.user_id), + aud: Some(db_token.client_id), + iss: Some(self.config.issuer_url.clone()), + jti: Some(db_token.token_id), + } + } else { + TokenIntrospectionResponse { + active: false, + client_id: None, + username: None, + scope: None, + exp: None, + iat: None, + sub: None, + aud: None, + iss: None, + jti: None, + } + } + } else { + TokenIntrospectionResponse { + active: false, + client_id: None, + username: None, + scope: None, + exp: None, + iat: None, + sub: None, + aud: None, + iss: None, + jti: None, + } + }; + + serde_json::to_string(&response) + .map_err(|_| self.error_response("server_error", "Failed to serialize response")) + } + + pub fn handle_token_revocation( + &self, + params: &HashMap<String, String>, + auth_header: Option<&str>, + ) -> Result<(), String> { + let token = params + .get("token") + .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?; + + // Authenticate the client making the revocation request + let (client_id, client_secret) = if let Some(auth_header) = auth_header { + parse_basic_auth(auth_header).ok_or_else(|| { + self.error_response("invalid_client", "Invalid Authorization header") + })? + } else { + return Err(self.error_response("invalid_client", "Client authentication required")); + }; + + { + let mut client_manager = self.client_manager.lock().unwrap(); + if !client_manager.authenticate_client(&client_id, &client_secret) { + return Err(self.error_response("invalid_client", "Client authentication failed")); + } + } + + // Revoke token in database + let token_hash = format!("{:x}", Sha256::digest(token.as_bytes())); + { + let db = self.database.lock().unwrap(); + let _ = db.revoke_access_token(&token_hash); // Ignore errors as per RFC 7009 + } + + self.audit_log("token_revoked", Some(&client_id), None, None, true, None); + + Ok(()) + } + fn generate_access_token( &self, user_id: &str, client_id: &str, scope: &Option<String>, + token_id: &str, ) -> Result<String, String> { let mut key_manager = self.key_manager.lock().unwrap(); @@ -215,6 +519,7 @@ impl OAuthServer { exp: now + 3600, iat: now, scope: scope.clone(), + jti: Some(token_id.to_string()), }; let mut header = Header::new(Algorithm::RS256); @@ -224,11 +529,71 @@ impl OAuthServer { .map_err(|_| self.error_response("server_error", "Failed to generate token")) } + fn generate_refresh_token( + &self, + _client_id: &str, + _user_id: &str, + _scope: &Option<String>, + ) -> Result<String, String> { + // For now, return a simple UUID-based refresh token + // In production, this should be a proper JWT or encrypted token + Ok(Uuid::new_v4().to_string()) + } + + fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> { + let db = self.database.lock().unwrap(); + let count = db.increment_rate_limit(identifier, endpoint, 1)?; + + if count > self.config.rate_limit_requests_per_minute as i32 { + return Err(anyhow!("Rate limit exceeded")); + } + + Ok(()) + } + + fn audit_log(&self, event_type: &str, client_id: Option<&str>, user_id: Option<&str>, ip_address: Option<&str>, success: bool, details: Option<&str>) { + if !self.config.enable_audit_logging { + return; + } + + let log = DbAuditLog { + id: 0, + event_type: event_type.to_string(), + client_id: client_id.map(|s| s.to_string()), + user_id: user_id.map(|s| s.to_string()), + ip_address: ip_address.map(|s| s.to_string()), + user_agent: None, // Could be passed in from HTTP layer + details: details.map(|s| s.to_string()), + created_at: Utc::now(), + success, + }; + + let db = self.database.lock().unwrap(); + let _ = db.create_audit_log(&log); // Ignore errors in audit logging + } + fn error_response(&self, error: &str, description: &str) -> String { let error_resp = ErrorResponse { error: error.to_string(), error_description: Some(description.to_string()), + error_uri: None, }; serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string()) } -} + + // Cleanup expired data + pub fn cleanup_expired_data(&self) -> Result<()> { + let db = self.database.lock().unwrap(); + + // Cleanup expired authorization codes + let _ = db.cleanup_expired_codes(); + + // Cleanup expired tokens + let _ = db.cleanup_expired_tokens(); + + // Cleanup old audit logs (keep for 30 days) + let _ = db.cleanup_old_audit_logs(30); + + Ok(()) + } +}
\ No newline at end of file diff --git a/src/oauth/types.rs b/src/oauth/types.rs index 6c62edf..0f9be5c 100644 --- a/src/oauth/types.rs +++ b/src/oauth/types.rs @@ -1,4 +1,5 @@ use serde::{Deserialize, Serialize}; +use crate::oauth::pkce::CodeChallengeMethod; #[derive(Debug, Serialize, Deserialize)] pub struct Claims { @@ -9,6 +10,8 @@ pub struct Claims { pub iat: u64, #[serde(skip_serializing_if = "Option::is_none")] pub scope: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub jti: Option<String>, // JWT ID for token tracking } #[derive(Debug, Serialize, Deserialize)] @@ -27,6 +30,8 @@ pub struct ErrorResponse { pub error: String, #[serde(skip_serializing_if = "Option::is_none")] pub error_description: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error_uri: Option<String>, } #[derive(Debug, Clone)] @@ -36,4 +41,44 @@ pub struct AuthCode { pub scope: Option<String>, pub expires_at: u64, pub user_id: String, + // PKCE support + pub code_challenge: Option<String>, + pub code_challenge_method: Option<CodeChallengeMethod>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TokenIntrospectionRequest { + pub token: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub token_type_hint: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TokenIntrospectionResponse { + pub active: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub client_id: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub username: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub scope: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub exp: Option<u64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub iat: Option<u64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub sub: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub aud: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub iss: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub jti: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TokenRevocationRequest { + pub token: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub token_type_hint: Option<String>, } |
