use crate::clients::{ClientManager, parse_basic_auth}; use crate::config::Config; use crate::database::{Database, DbAccessToken, DbAuditLog, DbAuthCode}; use crate::keys::KeyManager; use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge}; use crate::oauth::types::{Claims, ErrorResponse, TokenIntrospectionResponse, TokenResponse}; use anyhow::{Result, anyhow}; use chrono::{Duration, Utc}; use jsonwebtoken::{Algorithm, Header, encode}; use sha2::{Digest, Sha256}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; use std::time::{SystemTime, UNIX_EPOCH}; use url::Url; use uuid::Uuid; pub struct OAuthServer { config: Config, key_manager: Arc>, client_manager: Arc>, database: Arc>, } impl OAuthServer { pub fn new(config: &Config) -> Result { let database = Arc::new(Mutex::new(Database::new(&config.database_path)?)); let key_manager = Arc::new(Mutex::new(KeyManager::new(database.clone())?)); let client_manager = Arc::new(Mutex::new(ClientManager::new(database.clone())?)); Ok(Self { key_manager, client_manager, database, config: config.clone(), }) } pub fn get_jwks(&self) -> String { let key_manager = self.key_manager.lock().unwrap(); match key_manager.get_jwks() { Ok(jwks) => serde_json::to_string(&jwks).unwrap_or_else(|_| "{}".to_string()), Err(_) => serde_json::json!({"keys": []}).to_string(), } } pub fn handle_authorize( &self, params: &HashMap, ip_address: Option, ) -> Result { let client_id = params .get("client_id") .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; let redirect_uri = params .get("redirect_uri") .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?; let response_type = params .get("response_type") .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?; // Rate limiting check if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/authorize") { self.audit_log( "authorize_rate_limited", Some(client_id), None, ip_address.as_deref(), false, Some(&e.to_string()), ); return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); } // Validate client exists let client = { let mut client_manager = self.client_manager.lock().unwrap(); match client_manager.get_client_from_db(client_id) { Ok(Some(client)) => client, Ok(None) => { self.audit_log( "authorize_invalid_client", Some(client_id), None, ip_address.as_deref(), false, None, ); return Err(self.error_response("invalid_client", "Invalid client_id")); } Err(_) => { return Err(self.error_response("server_error", "Internal server error")); } } }; // Validate redirect URI is registered for this client { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.is_redirect_uri_valid(client_id, redirect_uri) { self.audit_log( "authorize_invalid_redirect_uri", Some(client_id), None, ip_address.as_deref(), false, Some(redirect_uri), ); return Err(self.error_response("invalid_request", "Invalid redirect_uri")); } } // Validate requested scopes let scope = params.get("scope").cloned(); { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.is_scope_valid(client_id, &scope) { self.audit_log( "authorize_invalid_scope", Some(client_id), None, ip_address.as_deref(), false, scope.as_deref(), ); return Err(self.error_response("invalid_scope", "Invalid scope")); } } if response_type != "code" { self.audit_log( "authorize_unsupported_response_type", Some(client_id), None, ip_address.as_deref(), false, Some(response_type), ); return Err(self.error_response( "unsupported_response_type", "Only code response type supported", )); } // PKCE validation (RFC 7636) let code_challenge = params.get("code_challenge"); let code_challenge_method = params .get("code_challenge_method") .map(|method| CodeChallengeMethod::from_str(method)) .transpose() .map_err(|_| self.error_response("invalid_request", "Invalid code_challenge_method"))?; // For public clients, PKCE is required if client.client_id.starts_with("public_") && code_challenge.is_none() { self.audit_log( "authorize_missing_pkce", Some(client_id), None, ip_address.as_deref(), false, None, ); return Err(self.error_response("invalid_request", "PKCE required for public clients")); } let code = Uuid::new_v4().to_string(); let expires_at = Utc::now() + Duration::minutes(10); // 10 minute expiration let db_auth_code = DbAuthCode { id: 0, // Will be set by database code: code.clone(), client_id: client_id.clone(), user_id: "test_user".to_string(), // In a real implementation, this would come from authentication redirect_uri: redirect_uri.clone(), scope: scope.clone(), expires_at, created_at: Utc::now(), is_used: false, code_challenge: code_challenge.cloned(), code_challenge_method: code_challenge_method .as_ref() .map(|m| m.as_str().to_string()), }; // Save to database { let db = self.database.lock().unwrap(); if let Err(_) = db.create_auth_code(&db_auth_code) { return Err( self.error_response("server_error", "Failed to create authorization code") ); } } let mut redirect_url = Url::parse(redirect_uri) .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?; redirect_url.query_pairs_mut().append_pair("code", &code); if let Some(state) = params.get("state") { redirect_url.query_pairs_mut().append_pair("state", state); } self.audit_log( "authorize_success", Some(client_id), Some("test_user"), ip_address.as_deref(), true, None, ); Ok(redirect_url.to_string()) } pub fn handle_token( &self, params: &HashMap, auth_header: Option<&str>, ip_address: Option, ) -> Result { let grant_type = params .get("grant_type") .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; match grant_type.as_str() { "authorization_code" => { self.handle_authorization_code_grant(params, auth_header, ip_address) } "refresh_token" => self.handle_refresh_token_grant(params, auth_header, ip_address), _ => { self.audit_log( "token_unsupported_grant_type", None, None, ip_address.as_deref(), false, Some(grant_type), ); Err(self.error_response("unsupported_grant_type", "Unsupported grant type")) } } } fn handle_authorization_code_grant( &self, params: &HashMap, auth_header: Option<&str>, ip_address: Option, ) -> Result { let code = params .get("code") .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?; // Client authentication - RFC 6749 Section 3.2.1 let (client_id, client_secret) = if let Some(auth_header) = auth_header { // HTTP Basic Authentication (preferred method) parse_basic_auth(auth_header).ok_or_else(|| { self.error_response("invalid_client", "Invalid Authorization header") })? } else { // Form-based authentication (fallback) let client_id = params .get("client_id") .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; let client_secret = params .get("client_secret") .ok_or_else(|| self.error_response("invalid_request", "Missing client_secret"))?; (client_id.clone(), client_secret.clone()) }; // Rate limiting check if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/token") { self.audit_log( "token_rate_limited", Some(&client_id), None, ip_address.as_deref(), false, Some(&e.to_string()), ); return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); } // Authenticate the client { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.authenticate_client(&client_id, &client_secret) { self.audit_log( "token_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None, ); return Err(self.error_response("invalid_client", "Client authentication failed")); } } // Get and validate authorization code let auth_code = { let db = self.database.lock().unwrap(); match db.get_auth_code(code) { Ok(Some(auth_code)) => auth_code, Ok(None) => { self.audit_log( "token_invalid_code", Some(&client_id), None, ip_address.as_deref(), false, Some(code), ); return Err(self .error_response("invalid_grant", "Invalid or expired authorization code")); } Err(_) => { return Err(self.error_response("server_error", "Internal server error")); } } }; // Validate code hasn't been used and hasn't expired if auth_code.is_used { self.audit_log( "token_code_reuse", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code), ); return Err(self.error_response("invalid_grant", "Authorization code already used")); } if Utc::now() > auth_code.expires_at { self.audit_log( "token_code_expired", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code), ); return Err(self.error_response("invalid_grant", "Authorization code expired")); } if auth_code.client_id != client_id { self.audit_log( "token_client_mismatch", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None, ); return Err(self.error_response("invalid_grant", "Client ID mismatch")); } // PKCE validation if code challenge was provided if let Some(code_challenge) = &auth_code.code_challenge { let code_verifier = params.get("code_verifier").ok_or_else(|| { self.error_response("invalid_request", "Missing code_verifier for PKCE") })?; let challenge_method = auth_code .code_challenge_method .as_ref() .and_then(|method| CodeChallengeMethod::from_str(method).ok()) .unwrap_or(CodeChallengeMethod::Plain); if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method) { self.audit_log( "token_pkce_verification_failed", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None, ); return Err(self.error_response("invalid_grant", "PKCE verification failed")); } } // Mark code as used { let db = self.database.lock().unwrap(); if let Err(_) = db.mark_auth_code_used(code) { return Err(self.error_response("server_error", "Failed to mark code as used")); } } // Generate tokens let token_id = Uuid::new_v4().to_string(); let access_token = self.generate_access_token( &auth_code.user_id, &client_id, &auth_code.scope, &token_id, )?; let refresh_token = self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?; // Store token in database for revocation/introspection let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); let db_access_token = DbAccessToken { id: 0, token_id: token_id.clone(), client_id: client_id.clone(), user_id: auth_code.user_id.clone(), scope: auth_code.scope.clone(), expires_at: Utc::now() + Duration::hours(1), created_at: Utc::now(), is_revoked: false, token_hash, }; { let db = self.database.lock().unwrap(); if let Err(_) = db.create_access_token(&db_access_token) { return Err(self.error_response("server_error", "Failed to store access token")); } } let token_response = TokenResponse { access_token, token_type: "Bearer".to_string(), expires_in: 3600, refresh_token: Some(refresh_token), scope: auth_code.scope, }; self.audit_log( "token_success", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), true, None, ); serde_json::to_string(&token_response) .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) } fn handle_refresh_token_grant( &self, params: &HashMap, auth_header: Option<&str>, ip_address: Option, ) -> Result { let _refresh_token = params .get("refresh_token") .ok_or_else(|| self.error_response("invalid_request", "Missing refresh_token"))?; // Client authentication let (client_id, client_secret) = if let Some(auth_header) = auth_header { parse_basic_auth(auth_header).ok_or_else(|| { self.error_response("invalid_client", "Invalid Authorization header") })? } else { let client_id = params .get("client_id") .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; let client_secret = params .get("client_secret") .ok_or_else(|| self.error_response("invalid_request", "Missing client_secret"))?; (client_id.clone(), client_secret.clone()) }; // Authenticate the client { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.authenticate_client(&client_id, &client_secret) { self.audit_log( "refresh_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None, ); return Err(self.error_response("invalid_client", "Client authentication failed")); } } // Validate refresh token (implementation would verify token and get user info) // For now, return a simple refresh token response let new_token_id = Uuid::new_v4().to_string(); let access_token = self.generate_access_token("test_user", &client_id, &None, &new_token_id)?; let new_refresh_token = self.generate_refresh_token(&client_id, "test_user", &None)?; let token_response = TokenResponse { access_token, token_type: "Bearer".to_string(), expires_in: 3600, refresh_token: Some(new_refresh_token), scope: None, }; self.audit_log( "refresh_success", Some(&client_id), Some("test_user"), ip_address.as_deref(), true, None, ); serde_json::to_string(&token_response) .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) } pub fn handle_token_introspection( &self, params: &HashMap, auth_header: Option<&str>, ) -> Result { let token = params .get("token") .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?; // Authenticate the client making the introspection request let (client_id, client_secret) = if let Some(auth_header) = auth_header { parse_basic_auth(auth_header).ok_or_else(|| { self.error_response("invalid_client", "Invalid Authorization header") })? } else { return Err(self.error_response("invalid_client", "Client authentication required")); }; { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.authenticate_client(&client_id, &client_secret) { return Err(self.error_response("invalid_client", "Client authentication failed")); } } // Look up token in database let token_hash = format!("{:x}", Sha256::digest(token.as_bytes())); let db_token = { let db = self.database.lock().unwrap(); db.get_access_token(&token_hash).ok().flatten() }; let response = if let Some(db_token) = db_token { if !db_token.is_revoked && Utc::now() < db_token.expires_at { TokenIntrospectionResponse { active: true, client_id: Some(db_token.client_id.clone()), username: Some(db_token.user_id.clone()), scope: db_token.scope.clone(), exp: Some(db_token.expires_at.timestamp() as u64), iat: Some(db_token.created_at.timestamp() as u64), sub: Some(db_token.user_id), aud: Some(db_token.client_id), iss: Some(self.config.issuer_url.clone()), jti: Some(db_token.token_id), } } else { TokenIntrospectionResponse { active: false, client_id: None, username: None, scope: None, exp: None, iat: None, sub: None, aud: None, iss: None, jti: None, } } } else { TokenIntrospectionResponse { active: false, client_id: None, username: None, scope: None, exp: None, iat: None, sub: None, aud: None, iss: None, jti: None, } }; serde_json::to_string(&response) .map_err(|_| self.error_response("server_error", "Failed to serialize response")) } pub fn handle_token_revocation( &self, params: &HashMap, auth_header: Option<&str>, ) -> Result<(), String> { let token = params .get("token") .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?; // Authenticate the client making the revocation request let (client_id, client_secret) = if let Some(auth_header) = auth_header { parse_basic_auth(auth_header).ok_or_else(|| { self.error_response("invalid_client", "Invalid Authorization header") })? } else { return Err(self.error_response("invalid_client", "Client authentication required")); }; { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.authenticate_client(&client_id, &client_secret) { return Err(self.error_response("invalid_client", "Client authentication failed")); } } // Revoke token in database let token_hash = format!("{:x}", Sha256::digest(token.as_bytes())); { let db = self.database.lock().unwrap(); let _ = db.revoke_access_token(&token_hash); // Ignore errors as per RFC 7009 } self.audit_log("token_revoked", Some(&client_id), None, None, true, None); Ok(()) } fn generate_access_token( &self, user_id: &str, client_id: &str, scope: &Option, token_id: &str, ) -> Result { let mut key_manager = self.key_manager.lock().unwrap(); // Check if we need to rotate keys if key_manager.should_rotate() { if let Err(_) = key_manager.rotate_keys() { return Err(self.error_response("server_error", "Key rotation failed")); } } let current_key = key_manager .get_current_key() .ok_or_else(|| self.error_response("server_error", "No signing key available"))?; let now = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); let claims = Claims { sub: user_id.to_string(), iss: self.config.issuer_url.clone(), aud: client_id.to_string(), exp: now + 3600, iat: now, scope: scope.clone(), jti: Some(token_id.to_string()), }; let mut header = Header::new(Algorithm::RS256); header.kid = Some(current_key.kid.clone()); encode(&header, &claims, ¤t_key.encoding_key) .map_err(|_| self.error_response("server_error", "Failed to generate token")) } fn generate_refresh_token( &self, _client_id: &str, _user_id: &str, _scope: &Option, ) -> Result { // For now, return a simple UUID-based refresh token // In production, this should be a proper JWT or encrypted token Ok(Uuid::new_v4().to_string()) } fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> { let db = self.database.lock().unwrap(); let count = db.increment_rate_limit(identifier, endpoint, 1)?; if count > self.config.rate_limit_requests_per_minute as i32 { return Err(anyhow!("Rate limit exceeded")); } Ok(()) } fn audit_log( &self, event_type: &str, client_id: Option<&str>, user_id: Option<&str>, ip_address: Option<&str>, success: bool, details: Option<&str>, ) { if !self.config.enable_audit_logging { return; } let log = DbAuditLog { id: 0, event_type: event_type.to_string(), client_id: client_id.map(|s| s.to_string()), user_id: user_id.map(|s| s.to_string()), ip_address: ip_address.map(|s| s.to_string()), user_agent: None, // Could be passed in from HTTP layer details: details.map(|s| s.to_string()), created_at: Utc::now(), success, }; let db = self.database.lock().unwrap(); let _ = db.create_audit_log(&log); // Ignore errors in audit logging } fn error_response(&self, error: &str, description: &str) -> String { let error_resp = ErrorResponse { error: error.to_string(), error_description: Some(description.to_string()), error_uri: None, }; serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string()) } // Cleanup expired data pub fn cleanup_expired_data(&self) -> Result<()> { let db = self.database.lock().unwrap(); // Cleanup expired authorization codes let _ = db.cleanup_expired_codes(); // Cleanup expired tokens let _ = db.cleanup_expired_tokens(); // Cleanup old audit logs (keep for 30 days) let _ = db.cleanup_old_audit_logs(30); Ok(()) } }