diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-11 17:25:40 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-11 17:25:40 -0600 |
| commit | 9fc51f0a8312d87b65adb661fc1c6757662d9479 (patch) | |
| tree | 1928134f8fe8e6bdb7ecd826e72b7af340f4295e /src | |
| parent | 0ff7c18f2e0e4f72cf6354530329c1c915c6294a (diff) | |
refactor: extract design patterns
Diffstat (limited to 'src')
| -rw-r--r-- | src/domain/dto.rs | 134 | ||||
| -rw-r--r-- | src/domain/mappers.rs | 209 | ||||
| -rw-r--r-- | src/domain/mod.rs | 12 | ||||
| -rw-r--r-- | src/domain/queries.rs | 307 | ||||
| -rw-r--r-- | src/domain/specifications.rs | 194 | ||||
| -rw-r--r-- | src/domain/unit_of_work.rs | 64 |
6 files changed, 919 insertions, 1 deletions
diff --git a/src/domain/dto.rs b/src/domain/dto.rs new file mode 100644 index 0000000..336db61 --- /dev/null +++ b/src/domain/dto.rs @@ -0,0 +1,134 @@ +use serde::{Deserialize, Serialize}; + +/// API DTOs for OAuth2 endpoints - these define the wire format +/// Separate from domain models to allow API versioning without affecting business logic + +#[derive(Debug, Serialize, Deserialize)] +pub struct AuthorizeRequestDto { + pub client_id: String, + pub redirect_uri: String, + pub response_type: String, + pub scope: Option<String>, + pub state: Option<String>, + pub code_challenge: Option<String>, + pub code_challenge_method: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TokenRequestDto { + pub grant_type: String, + pub code: Option<String>, + pub refresh_token: Option<String>, + pub redirect_uri: Option<String>, + pub client_id: Option<String>, + pub client_secret: Option<String>, + pub code_verifier: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TokenResponseDto { + pub access_token: String, + pub token_type: String, + pub expires_in: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub refresh_token: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub scope: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ErrorResponseDto { + pub error: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub error_description: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error_uri: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct IntrospectionRequestDto { + pub token: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub token_type_hint: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct IntrospectionResponseDto { + pub active: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub client_id: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub username: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub scope: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub exp: Option<u64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub iat: Option<u64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub sub: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub aud: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub iss: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub jti: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RevocationRequestDto { + pub token: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub token_type_hint: Option<String>, +} + +// Conversion traits between DTOs and domain models +impl From<crate::domain::AuthorizationRequest> for AuthorizeRequestDto { + fn from(req: crate::domain::AuthorizationRequest) -> Self { + Self { + client_id: req.client_id, + redirect_uri: req.redirect_uri, + response_type: req.response_type, + scope: req.scope, + state: req.state, + code_challenge: req.code_challenge, + code_challenge_method: req.code_challenge_method, + } + } +} + +impl From<AuthorizeRequestDto> for crate::domain::AuthorizationRequest { + fn from(dto: AuthorizeRequestDto) -> Self { + Self { + client_id: dto.client_id, + redirect_uri: dto.redirect_uri, + response_type: dto.response_type, + scope: dto.scope, + state: dto.state, + code_challenge: dto.code_challenge, + code_challenge_method: dto.code_challenge_method, + } + } +} + +impl From<crate::domain::TokenResult> for TokenResponseDto { + fn from(result: crate::domain::TokenResult) -> Self { + Self { + access_token: result.access_token, + token_type: result.token_type, + expires_in: result.expires_in, + refresh_token: result.refresh_token, + scope: result.scope, + } + } +} + +impl From<crate::domain::OAuthError> for ErrorResponseDto { + fn from(error: crate::domain::OAuthError) -> Self { + Self { + error: error.error_code, + error_description: error.description, + error_uri: error.uri, + } + } +}
\ No newline at end of file diff --git a/src/domain/mappers.rs b/src/domain/mappers.rs new file mode 100644 index 0000000..6efe276 --- /dev/null +++ b/src/domain/mappers.rs @@ -0,0 +1,209 @@ +use crate::database::{DbAccessToken, DbAuthCode, DbAuditLog, DbOAuthClient}; +use crate::domain::models::*; +use anyhow::Result; + +/// Data Mapper pattern - responsible for moving data between domain objects and database +pub trait DataMapper<Domain, Database> { + fn to_domain(&self, db_model: Database) -> Result<Domain>; + fn to_database(&self, domain_model: &Domain) -> Result<Database>; +} + +/// OAuth Client Data Mapper +pub struct OAuthClientMapper; + +impl DataMapper<OAuthClient, DbOAuthClient> for OAuthClientMapper { + fn to_domain(&self, db_client: DbOAuthClient) -> Result<OAuthClient> { + let redirect_uris: Vec<String> = serde_json::from_str(&db_client.redirect_uris)?; + let scopes: Vec<String> = db_client.scopes.split_whitespace().map(|s| s.to_string()).collect(); + let grant_types: Vec<String> = db_client.grant_types.split_whitespace().map(|s| s.to_string()).collect(); + let response_types: Vec<String> = db_client.response_types.split_whitespace().map(|s| s.to_string()).collect(); + + Ok(OAuthClient { + client_id: db_client.client_id, + client_name: db_client.client_name, + redirect_uris, + scopes, + grant_types, + response_types, + is_active: db_client.is_active, + created_at: db_client.created_at, + updated_at: db_client.updated_at, + }) + } + + fn to_database(&self, client: &OAuthClient) -> Result<DbOAuthClient> { + Ok(DbOAuthClient { + id: 0, // Will be set by database + client_id: client.client_id.clone(), + client_secret_hash: String::new(), // Will be set separately + client_name: client.client_name.clone(), + redirect_uris: serde_json::to_string(&client.redirect_uris)?, + scopes: client.scopes.join(" "), + grant_types: client.grant_types.join(" "), + response_types: client.response_types.join(" "), + created_at: client.created_at, + updated_at: client.updated_at, + is_active: client.is_active, + }) + } +} + +/// Authorization Code Data Mapper +pub struct AuthCodeMapper; + +impl DataMapper<AuthorizationCode, DbAuthCode> for AuthCodeMapper { + fn to_domain(&self, db_code: DbAuthCode) -> Result<AuthorizationCode> { + let scopes = db_code.scope + .map(|s| s.split_whitespace().map(|scope| scope.to_string()).collect()) + .unwrap_or_default(); + + Ok(AuthorizationCode { + code: db_code.code, + client_id: db_code.client_id, + user_id: db_code.user_id, + redirect_uri: db_code.redirect_uri, + scopes, + expires_at: db_code.expires_at, + created_at: db_code.created_at, + is_used: db_code.is_used, + code_challenge: db_code.code_challenge, + code_challenge_method: db_code.code_challenge_method, + }) + } + + fn to_database(&self, code: &AuthorizationCode) -> Result<DbAuthCode> { + let scope = if code.scopes.is_empty() { + None + } else { + Some(code.scopes.join(" ")) + }; + + Ok(DbAuthCode { + id: 0, // Will be set by database + code: code.code.clone(), + client_id: code.client_id.clone(), + user_id: code.user_id.clone(), + redirect_uri: code.redirect_uri.clone(), + scope, + expires_at: code.expires_at, + created_at: code.created_at, + is_used: code.is_used, + code_challenge: code.code_challenge.clone(), + code_challenge_method: code.code_challenge_method.clone(), + }) + } +} + +/// Access Token Data Mapper +pub struct AccessTokenMapper; + +impl DataMapper<AccessToken, DbAccessToken> for AccessTokenMapper { + fn to_domain(&self, db_token: DbAccessToken) -> Result<AccessToken> { + let scopes = db_token.scope + .map(|s| s.split_whitespace().map(|scope| scope.to_string()).collect()) + .unwrap_or_default(); + + Ok(AccessToken { + token_id: db_token.token_id, + client_id: db_token.client_id, + user_id: db_token.user_id, + scopes, + expires_at: db_token.expires_at, + created_at: db_token.created_at, + is_revoked: db_token.is_revoked, + }) + } + + fn to_database(&self, token: &AccessToken) -> Result<DbAccessToken> { + let scope = if token.scopes.is_empty() { + None + } else { + Some(token.scopes.join(" ")) + }; + + Ok(DbAccessToken { + id: 0, // Will be set by database + token_id: token.token_id.clone(), + client_id: token.client_id.clone(), + user_id: token.user_id.clone(), + scope, + expires_at: token.expires_at, + created_at: token.created_at, + is_revoked: token.is_revoked, + token_hash: String::new(), // Will be set by service layer + }) + } +} + +/// Audit Event Data Mapper +pub struct AuditEventMapper; + +impl DataMapper<AuditEvent, DbAuditLog> for AuditEventMapper { + fn to_domain(&self, db_log: DbAuditLog) -> Result<AuditEvent> { + Ok(AuditEvent { + event_type: db_log.event_type, + client_id: db_log.client_id, + user_id: db_log.user_id, + ip_address: db_log.ip_address, + user_agent: db_log.user_agent, + details: db_log.details, + success: db_log.success, + timestamp: db_log.created_at, + }) + } + + fn to_database(&self, event: &AuditEvent) -> Result<DbAuditLog> { + Ok(DbAuditLog { + id: 0, // Will be set by database + event_type: event.event_type.clone(), + client_id: event.client_id.clone(), + user_id: event.user_id.clone(), + ip_address: event.ip_address.clone(), + user_agent: event.user_agent.clone(), + details: event.details.clone(), + created_at: event.timestamp, + success: event.success, + }) + } +} + +/// Registry of all data mappers +pub struct MapperRegistry { + client_mapper: OAuthClientMapper, + auth_code_mapper: AuthCodeMapper, + access_token_mapper: AccessTokenMapper, + audit_event_mapper: AuditEventMapper, +} + +impl MapperRegistry { + pub fn new() -> Self { + Self { + client_mapper: OAuthClientMapper, + auth_code_mapper: AuthCodeMapper, + access_token_mapper: AccessTokenMapper, + audit_event_mapper: AuditEventMapper, + } + } + + pub fn client_mapper(&self) -> &OAuthClientMapper { + &self.client_mapper + } + + pub fn auth_code_mapper(&self) -> &AuthCodeMapper { + &self.auth_code_mapper + } + + pub fn access_token_mapper(&self) -> &AccessTokenMapper { + &self.access_token_mapper + } + + pub fn audit_event_mapper(&self) -> &AuditEventMapper { + &self.audit_event_mapper + } +} + +impl Default for MapperRegistry { + fn default() -> Self { + Self::new() + } +}
\ No newline at end of file diff --git a/src/domain/mod.rs b/src/domain/mod.rs index 168ec01..9a8bfca 100644 --- a/src/domain/mod.rs +++ b/src/domain/mod.rs @@ -1,9 +1,19 @@ pub mod conversions; +pub mod dto; +pub mod mappers; pub mod models; +pub mod queries; pub mod repositories; pub mod services; +pub mod specifications; +pub mod unit_of_work; pub use conversions::*; +pub use dto::*; +pub use mappers::*; pub use models::*; +pub use queries::*; pub use repositories::*; -pub use services::*;
\ No newline at end of file +pub use services::*; +pub use specifications::*; +pub use unit_of_work::*;
\ No newline at end of file diff --git a/src/domain/queries.rs b/src/domain/queries.rs new file mode 100644 index 0000000..d4eb19e --- /dev/null +++ b/src/domain/queries.rs @@ -0,0 +1,307 @@ +use crate::domain::models::*; +use anyhow::Result; +use chrono::{DateTime, Utc}; + +/// Query Object pattern for complex database queries +pub trait Query<T> { + fn execute(&self) -> Result<Vec<T>>; +} + +/// Base query criteria +#[derive(Debug, Clone)] +pub struct QueryCriteria { + pub limit: Option<u32>, + pub offset: Option<u32>, + pub order_by: Option<String>, + pub order_direction: Option<OrderDirection>, +} + +#[derive(Debug, Clone)] +pub enum OrderDirection { + Asc, + Desc, +} + +impl Default for QueryCriteria { + fn default() -> Self { + Self { + limit: Some(100), + offset: None, + order_by: None, + order_direction: Some(OrderDirection::Desc), + } + } +} + +/// Audit Events Query Object +#[derive(Debug, Clone)] +pub struct AuditEventsQuery { + pub criteria: QueryCriteria, + pub client_id: Option<String>, + pub user_id: Option<String>, + pub event_type: Option<String>, + pub success: Option<bool>, + pub date_from: Option<DateTime<Utc>>, + pub date_to: Option<DateTime<Utc>>, + pub ip_address: Option<String>, +} + +impl AuditEventsQuery { + pub fn new() -> Self { + Self { + criteria: QueryCriteria::default(), + client_id: None, + user_id: None, + event_type: None, + success: None, + date_from: None, + date_to: None, + ip_address: None, + } + } + + pub fn for_client(mut self, client_id: &str) -> Self { + self.client_id = Some(client_id.to_string()); + self + } + + pub fn for_user(mut self, user_id: &str) -> Self { + self.user_id = Some(user_id.to_string()); + self + } + + pub fn event_type(mut self, event_type: &str) -> Self { + self.event_type = Some(event_type.to_string()); + self + } + + pub fn successful_only(mut self) -> Self { + self.success = Some(true); + self + } + + pub fn failed_only(mut self) -> Self { + self.success = Some(false); + self + } + + pub fn date_range(mut self, from: DateTime<Utc>, to: DateTime<Utc>) -> Self { + self.date_from = Some(from); + self.date_to = Some(to); + self + } + + pub fn from_ip(mut self, ip_address: &str) -> Self { + self.ip_address = Some(ip_address.to_string()); + self + } + + pub fn limit(mut self, limit: u32) -> Self { + self.criteria.limit = Some(limit); + self + } + + pub fn offset(mut self, offset: u32) -> Self { + self.criteria.offset = Some(offset); + self + } +} + +/// OAuth Clients Query Object +#[derive(Debug, Clone)] +pub struct OAuthClientsQuery { + pub criteria: QueryCriteria, + pub is_active: Option<bool>, + pub client_name_contains: Option<String>, + pub has_scope: Option<String>, + pub grant_type: Option<String>, +} + +impl OAuthClientsQuery { + pub fn new() -> Self { + Self { + criteria: QueryCriteria::default(), + is_active: None, + client_name_contains: None, + has_scope: None, + grant_type: None, + } + } + + pub fn active_only(mut self) -> Self { + self.is_active = Some(true); + self + } + + pub fn inactive_only(mut self) -> Self { + self.is_active = Some(false); + self + } + + pub fn name_contains(mut self, name_part: &str) -> Self { + self.client_name_contains = Some(name_part.to_string()); + self + } + + pub fn with_scope(mut self, scope: &str) -> Self { + self.has_scope = Some(scope.to_string()); + self + } + + pub fn with_grant_type(mut self, grant_type: &str) -> Self { + self.grant_type = Some(grant_type.to_string()); + self + } +} + +/// Token Usage Analytics Query +#[derive(Debug, Clone)] +pub struct TokenUsageQuery { + pub criteria: QueryCriteria, + pub client_id: Option<String>, + pub date_from: Option<DateTime<Utc>>, + pub date_to: Option<DateTime<Utc>>, + pub group_by: TokenUsageGroupBy, +} + +#[derive(Debug, Clone)] +pub enum TokenUsageGroupBy { + Hour, + Day, + Week, + Month, + Client, +} + +impl TokenUsageQuery { + pub fn new(group_by: TokenUsageGroupBy) -> Self { + Self { + criteria: QueryCriteria::default(), + client_id: None, + date_from: None, + date_to: None, + group_by, + } + } + + pub fn for_client(mut self, client_id: &str) -> Self { + self.client_id = Some(client_id.to_string()); + self + } + + pub fn date_range(mut self, from: DateTime<Utc>, to: DateTime<Utc>) -> Self { + self.date_from = Some(from); + self.date_to = Some(to); + self + } +} + +/// Token Usage Statistics Result +#[derive(Debug, Clone)] +pub struct TokenUsageStats { + pub period: String, + pub client_id: Option<String>, + pub token_count: u32, + pub unique_users: u32, + pub success_rate: f64, +} + +/// Failed Authorization Attempts Query +#[derive(Debug, Clone)] +pub struct FailedAuthQuery { + pub criteria: QueryCriteria, + pub client_id: Option<String>, + pub ip_address: Option<String>, + pub date_from: Option<DateTime<Utc>>, + pub date_to: Option<DateTime<Utc>>, + pub min_attempts: u32, +} + +impl FailedAuthQuery { + pub fn new() -> Self { + Self { + criteria: QueryCriteria::default(), + client_id: None, + ip_address: None, + date_from: None, + date_to: None, + min_attempts: 5, // Minimum failed attempts to be considered suspicious + } + } + + pub fn for_client(mut self, client_id: &str) -> Self { + self.client_id = Some(client_id.to_string()); + self + } + + pub fn from_ip(mut self, ip_address: &str) -> Self { + self.ip_address = Some(ip_address.to_string()); + self + } + + pub fn min_attempts(mut self, attempts: u32) -> Self { + self.min_attempts = attempts; + self + } + + pub fn last_24_hours(mut self) -> Self { + let now = Utc::now(); + let yesterday = now - chrono::Duration::hours(24); + self.date_from = Some(yesterday); + self.date_to = Some(now); + self + } +} + +/// Failed Authorization Result +#[derive(Debug, Clone)] +pub struct FailedAuthResult { + pub client_id: Option<String>, + pub ip_address: Option<String>, + pub attempt_count: u32, + pub first_attempt: DateTime<Utc>, + pub last_attempt: DateTime<Utc>, +} + +/// Query executor trait +pub trait QueryExecutor { + fn execute_audit_query(&self, query: &AuditEventsQuery) -> Result<Vec<AuditEvent>>; + fn execute_client_query(&self, query: &OAuthClientsQuery) -> Result<Vec<OAuthClient>>; + fn execute_token_usage_query(&self, query: &TokenUsageQuery) -> Result<Vec<TokenUsageStats>>; + fn execute_failed_auth_query(&self, query: &FailedAuthQuery) -> Result<Vec<FailedAuthResult>>; +} + +/// Predefined queries for common use cases +pub struct CommonQueries; + +impl CommonQueries { + /// Get recent failed login attempts (security monitoring) + pub fn recent_failed_logins() -> FailedAuthQuery { + FailedAuthQuery::new() + .last_24_hours() + .min_attempts(3) + } + + /// Get audit trail for a specific client + pub fn client_audit_trail(client_id: &str) -> AuditEventsQuery { + AuditEventsQuery::new() + .for_client(client_id) + .limit(1000) + } + + /// Get token usage statistics for the last 30 days + pub fn monthly_token_usage() -> TokenUsageQuery { + let now = Utc::now(); + let thirty_days_ago = now - chrono::Duration::days(30); + + TokenUsageQuery::new(TokenUsageGroupBy::Day) + .date_range(thirty_days_ago, now) + } + + /// Get all active clients with OpenID scope + pub fn openid_clients() -> OAuthClientsQuery { + OAuthClientsQuery::new() + .active_only() + .with_scope("openid") + } +}
\ No newline at end of file diff --git a/src/domain/specifications.rs b/src/domain/specifications.rs new file mode 100644 index 0000000..3237d1b --- /dev/null +++ b/src/domain/specifications.rs @@ -0,0 +1,194 @@ +use crate::domain::models::*; + +/// Specification pattern for encapsulating business rules +pub trait Specification<T> { + fn is_satisfied_by(&self, candidate: &T) -> bool; + fn reason_for_failure(&self, candidate: &T) -> Option<String>; +} + +/// Composite specifications +pub struct AndSpecification<T> { + left: Box<dyn Specification<T>>, + right: Box<dyn Specification<T>>, +} + +impl<T> AndSpecification<T> { + pub fn new(left: Box<dyn Specification<T>>, right: Box<dyn Specification<T>>) -> Self { + Self { left, right } + } +} + +impl<T> Specification<T> for AndSpecification<T> { + fn is_satisfied_by(&self, candidate: &T) -> bool { + self.left.is_satisfied_by(candidate) && self.right.is_satisfied_by(candidate) + } + + fn reason_for_failure(&self, candidate: &T) -> Option<String> { + if !self.left.is_satisfied_by(candidate) { + self.left.reason_for_failure(candidate) + } else if !self.right.is_satisfied_by(candidate) { + self.right.reason_for_failure(candidate) + } else { + None + } + } +} + +/// OAuth2 Client Specifications +pub struct ActiveClientSpecification; +impl Specification<OAuthClient> for ActiveClientSpecification { + fn is_satisfied_by(&self, client: &OAuthClient) -> bool { + client.is_active + } + + fn reason_for_failure(&self, _client: &OAuthClient) -> Option<String> { + Some("Client is not active".to_string()) + } +} + +pub struct ValidRedirectUriSpecification { + redirect_uri: String, +} + +impl ValidRedirectUriSpecification { + pub fn new(redirect_uri: String) -> Self { + Self { redirect_uri } + } +} + +impl Specification<OAuthClient> for ValidRedirectUriSpecification { + fn is_satisfied_by(&self, client: &OAuthClient) -> bool { + client.redirect_uris.contains(&self.redirect_uri) + } + + fn reason_for_failure(&self, _client: &OAuthClient) -> Option<String> { + Some(format!("Invalid redirect_uri: {}", self.redirect_uri)) + } +} + +pub struct SupportedScopesSpecification { + requested_scopes: Vec<String>, +} + +impl SupportedScopesSpecification { + pub fn new(requested_scopes: Vec<String>) -> Self { + Self { requested_scopes } + } +} + +impl Specification<OAuthClient> for SupportedScopesSpecification { + fn is_satisfied_by(&self, client: &OAuthClient) -> bool { + self.requested_scopes.iter().all(|scope| client.scopes.contains(scope)) + } + + fn reason_for_failure(&self, client: &OAuthClient) -> Option<String> { + let unsupported: Vec<_> = self.requested_scopes.iter() + .filter(|scope| !client.scopes.contains(scope)) + .cloned() + .collect(); + + if !unsupported.is_empty() { + Some(format!("Unsupported scopes: {}", unsupported.join(", "))) + } else { + None + } + } +} + +/// Authorization Code Specifications +pub struct UnusedAuthCodeSpecification; +impl Specification<AuthorizationCode> for UnusedAuthCodeSpecification { + fn is_satisfied_by(&self, code: &AuthorizationCode) -> bool { + !code.is_used + } + + fn reason_for_failure(&self, _code: &AuthorizationCode) -> Option<String> { + Some("Authorization code has already been used".to_string()) + } +} + +pub struct ValidAuthCodeSpecification; +impl Specification<AuthorizationCode> for ValidAuthCodeSpecification { + fn is_satisfied_by(&self, code: &AuthorizationCode) -> bool { + chrono::Utc::now() < code.expires_at + } + + fn reason_for_failure(&self, _code: &AuthorizationCode) -> Option<String> { + Some("Authorization code has expired".to_string()) + } +} + +pub struct MatchingClientSpecification { + client_id: String, +} + +impl MatchingClientSpecification { + pub fn new(client_id: String) -> Self { + Self { client_id } + } +} + +impl Specification<AuthorizationCode> for MatchingClientSpecification { + fn is_satisfied_by(&self, code: &AuthorizationCode) -> bool { + code.client_id == self.client_id + } + + fn reason_for_failure(&self, _code: &AuthorizationCode) -> Option<String> { + Some("Client ID mismatch".to_string()) + } +} + +/// PKCE Specifications +pub struct ValidPkceSpecification { + code_verifier: String, +} + +impl ValidPkceSpecification { + pub fn new(code_verifier: String) -> Self { + Self { code_verifier } + } +} + +impl Specification<AuthorizationCode> for ValidPkceSpecification { + fn is_satisfied_by(&self, code: &AuthorizationCode) -> bool { + if let Some(challenge) = &code.code_challenge { + let method = code.code_challenge_method.as_deref().unwrap_or("plain"); + crate::oauth::pkce::verify_code_challenge(&self.code_verifier, challenge, + &crate::oauth::pkce::CodeChallengeMethod::from_str(method).unwrap_or(crate::oauth::pkce::CodeChallengeMethod::Plain) + ).is_ok() + } else { + true // No PKCE required + } + } + + fn reason_for_failure(&self, _code: &AuthorizationCode) -> Option<String> { + Some("PKCE verification failed".to_string()) + } +} + +/// Access Token Specifications +pub struct ValidTokenSpecification; +impl Specification<AccessToken> for ValidTokenSpecification { + fn is_satisfied_by(&self, token: &AccessToken) -> bool { + !token.is_revoked && chrono::Utc::now() < token.expires_at + } + + fn reason_for_failure(&self, token: &AccessToken) -> Option<String> { + if token.is_revoked { + Some("Token has been revoked".to_string()) + } else if chrono::Utc::now() >= token.expires_at { + Some("Token has expired".to_string()) + } else { + None + } + } +} + +/// Helper trait for chaining specifications +pub trait SpecificationExt<T>: Specification<T> + Sized + 'static { + fn and(self, other: impl Specification<T> + 'static) -> AndSpecification<T> { + AndSpecification::new(Box::new(self), Box::new(other)) + } +} + +impl<T, S: Specification<T> + 'static> SpecificationExt<T> for S {}
\ No newline at end of file diff --git a/src/domain/unit_of_work.rs b/src/domain/unit_of_work.rs new file mode 100644 index 0000000..db8294a --- /dev/null +++ b/src/domain/unit_of_work.rs @@ -0,0 +1,64 @@ +use anyhow::Result; +use std::sync::Arc; + +/// Unit of Work pattern for managing transactional boundaries +pub trait UnitOfWork: Send + Sync { + /// Begin a new transaction + fn begin(&self) -> Result<Box<dyn Transaction>>; +} + +/// Transaction interface for atomic operations +pub trait Transaction: Send + Sync { + /// Commit all changes in this transaction + fn commit(self: Box<Self>) -> Result<()>; + + /// Rollback all changes in this transaction + fn rollback(self: Box<Self>) -> Result<()>; + + /// Get repositories within this transaction context + fn client_repository(&self) -> Arc<dyn crate::domain::DomainClientRepository>; + fn auth_code_repository(&self) -> Arc<dyn crate::domain::DomainAuthCodeRepository>; + fn token_repository(&self) -> Arc<dyn crate::domain::DomainTokenRepository>; + fn audit_repository(&self) -> Arc<dyn crate::domain::DomainAuditRepository>; +} + +/// OAuth2-specific transactional operations +pub struct OAuthUnitOfWork { + uow: Arc<dyn UnitOfWork>, +} + +impl OAuthUnitOfWork { + pub fn new(uow: Arc<dyn UnitOfWork>) -> Self { + Self { uow } + } + + /// Execute OAuth2 authorization code exchange atomically + pub fn exchange_authorization_code<F>(&self, operation: F) -> Result<()> + where + F: FnOnce(&dyn Transaction) -> Result<()>, + { + let tx = self.uow.begin()?; + match operation(tx.as_ref()) { + Ok(_) => tx.commit(), + Err(e) => { + let _ = tx.rollback(); // Log but don't override original error + Err(e) + } + } + } + + /// Execute token refresh atomically + pub fn refresh_tokens<F>(&self, operation: F) -> Result<()> + where + F: FnOnce(&dyn Transaction) -> Result<()>, + { + let tx = self.uow.begin()?; + match operation(tx.as_ref()) { + Ok(_) => tx.commit(), + Err(e) => { + let _ = tx.rollback(); + Err(e) + } + } + } +}
\ No newline at end of file |
