diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-11 20:20:04 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-11 20:20:04 -0600 |
| commit | c28b7088b6fad045060a52b6e1a2249e876090e3 (patch) | |
| tree | a8fc26fd5365d4988d9206b32d94f51047cf0bcc | |
| parent | 19ca22e604f9bcdf6b25f973f81b2486b0dcb789 (diff) | |
refactor: extract domain model
| -rw-r--r-- | Cargo.lock | 49 | ||||
| -rw-r--r-- | Cargo.toml | 11 | ||||
| -rw-r--r-- | migrations/20241201000000_initial_schema.sql (renamed from migrations/001_initial_schema.sql) | 0 | ||||
| -rw-r--r-- | spec/integration/server_spec.rb | 4 | ||||
| -rw-r--r-- | src/bin/generate_migration.rs | 22 | ||||
| -rw-r--r-- | src/bin/migrate.rs | 8 | ||||
| -rw-r--r-- | src/bin/test.rs | 154 | ||||
| -rw-r--r-- | src/database.rs | 2 | ||||
| -rw-r--r-- | src/domain/conversions.rs | 40 | ||||
| -rw-r--r-- | src/domain/dto.rs | 2 | ||||
| -rw-r--r-- | src/domain/mappers.rs | 40 | ||||
| -rw-r--r-- | src/domain/mod.rs | 2 | ||||
| -rw-r--r-- | src/domain/models.rs | 8 | ||||
| -rw-r--r-- | src/domain/queries.rs | 19 | ||||
| -rw-r--r-- | src/domain/repositories.rs | 9 | ||||
| -rw-r--r-- | src/domain/services.rs | 70 | ||||
| -rw-r--r-- | src/domain/specifications.rs | 40 | ||||
| -rw-r--r-- | src/domain/unit_of_work.rs | 10 | ||||
| -rw-r--r-- | src/lib.rs | 1 | ||||
| -rw-r--r-- | src/migration_discovery.rs | 296 | ||||
| -rw-r--r-- | src/migrations.rs | 103 |
21 files changed, 763 insertions, 127 deletions
@@ -220,6 +220,16 @@ dependencies = [ ] [[package]] +name = "errno" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + +[[package]] name = "fallible-iterator" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -232,6 +242,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" [[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] name = "form_urlencoded" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -494,6 +510,12 @@ dependencies = [ ] [[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + +[[package]] name = "litemap" version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -843,6 +865,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f" [[package]] +name = "rustix" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + +[[package]] name = "rustversion" version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -993,6 +1028,7 @@ dependencies = [ "serde_json", "sha2", "subtle", + "tempfile", "tokio", "url", "urlencoding", @@ -1028,6 +1064,19 @@ dependencies = [ ] [[package]] 
+name = "tempfile" +version = "3.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" +dependencies = [ + "fastrand", + "getrandom 0.3.3", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + +[[package]] name = "thiserror" version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -15,6 +15,14 @@ path = "src/bin/migrate.rs" name = "debug" path = "src/bin/debug.rs" +[[bin]] +name = "generate_migration" +path = "src/bin/generate_migration.rs" + +[[bin]] +name = "test" +path = "src/bin/test.rs" + [dependencies] serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -31,3 +39,6 @@ chrono = { version = "0.4", features = ["serde"] } tokio = { version = "1.0", features = ["full"] } anyhow = "1.0" subtle = "2.4" + +[dev-dependencies] +tempfile = "3.0" diff --git a/migrations/001_initial_schema.sql b/migrations/20241201000000_initial_schema.sql index 8796157..8796157 100644 --- a/migrations/001_initial_schema.sql +++ b/migrations/20241201000000_initial_schema.sql diff --git a/spec/integration/server_spec.rb b/spec/integration/server_spec.rb index 229203c..36cb81c 100644 --- a/spec/integration/server_spec.rb +++ b/spec/integration/server_spec.rb @@ -62,6 +62,10 @@ RSpec.describe "Server" do end end + describe "GET /authorize" do + pending + end + # https://datatracker.ietf.org/doc/html/rfc8693#section-2.3 describe "POST /token" do pending diff --git a/src/bin/generate_migration.rs b/src/bin/generate_migration.rs new file mode 100644 index 0000000..c72db66 --- /dev/null +++ b/src/bin/generate_migration.rs @@ -0,0 +1,22 @@ +use std::env; +use sts::migration_discovery::{generate_migration_filename, generate_migration_timestamp}; + +fn main() { + let args: Vec<String> = env::args().collect(); + + if args.len() < 2 { + eprintln!("Usage: cargo run --bin generate_migration <migration_name>"); + eprintln!("Example: cargo run --bin 
generate_migration add_users_table"); + return; + } + + let migration_name = &args[1]; + let filename = generate_migration_filename(migration_name); + let timestamp = generate_migration_timestamp(); + + println!("Generated migration filename: {}", filename); + println!("Timestamp: {}", timestamp); + println!(""); + println!("To create the migration file:"); + println!("touch migrations/{}", filename); +} diff --git a/src/bin/migrate.rs b/src/bin/migrate.rs index 9a0bab9..fbf5183 100644 --- a/src/bin/migrate.rs +++ b/src/bin/migrate.rs @@ -13,7 +13,7 @@ fn main() -> Result<()> { let config = Config::from_env(); let conn = Connection::open(&config.database_path)?; - let runner = MigrationRunner::new(&conn); + let runner = MigrationRunner::new(&conn)?; match args[1].as_str() { "up" => { @@ -29,7 +29,7 @@ fn main() -> Result<()> { eprintln!("Usage: cargo run --bin migrate rollback <version>"); return Ok(()); } - let version: i32 = args[2] + let version: i64 = args[2] .parse() .map_err(|_| anyhow::anyhow!("Invalid version number: {}", args[2]))?; runner.rollback_to_version(version)?; @@ -50,7 +50,7 @@ fn print_usage() { println!("Usage:"); println!(" cargo run --bin migrate up # Run pending migrations"); println!(" cargo run --bin migrate status # Show migration status"); - println!(" cargo run --bin migrate rollback <version> # Rollback to version"); + println!(" cargo run --bin migrate rollback <timestamp> # Rollback to timestamp"); println!(""); println!("Environment Variables:"); println!(" DATABASE_PATH Path to SQLite database (default: oauth.db)"); @@ -58,5 +58,5 @@ fn print_usage() { println!("Examples:"); println!(" cargo run --bin migrate up"); println!(" cargo run --bin migrate status"); - println!(" cargo run --bin migrate rollback 0"); + println!(" cargo run --bin migrate rollback 20231201120000"); } diff --git a/src/bin/test.rs b/src/bin/test.rs new file mode 100644 index 0000000..d2704c4 --- /dev/null +++ b/src/bin/test.rs @@ -0,0 +1,154 @@ +use std::env; 
+use std::process::{Command, exit}; + +fn main() { + let args: Vec<String> = env::args().collect(); + + if args.len() < 2 { + print_usage(); + return; + } + + match args[1].as_str() { + "unit" => { + println!("Running unit tests..."); + run_command(&["cargo", "test", "--lib"]); + } + "integration" => { + println!("Running integration tests..."); + run_command_allow_failure(&["bundle", "exec", "rspec"]); + } + "all" => { + println!("Running all tests..."); + println!("==================="); + println!(); + + println!("1. Running Rust unit tests..."); + run_command(&["cargo", "test"]); + println!(); + + println!("2. Running Ruby integration tests..."); + run_command_allow_failure(&["bundle", "exec", "rspec"]); + } + "watch" => { + println!("Running tests in watch mode..."); + if args.len() > 2 && args[2] == "integration" { + run_command(&["bundle", "exec", "guard"]); + } else { + run_command(&["cargo", "watch", "-x", "test"]); + } + } + "coverage" => { + println!("Running tests with coverage..."); + run_command(&["cargo", "tarpaulin", "--out", "Html"]); + } + "check" => { + println!("Running cargo check..."); + run_command(&["cargo", "check"]); + } + "lint" => { + println!("Running linting..."); + run_command(&["cargo", "clippy", "--", "-D", "warnings"]); + } + "fmt" => { + println!("Running code formatting..."); + run_command(&["cargo", "fmt"]); + } + "clean" => { + println!("Cleaning test artifacts..."); + run_command(&["cargo", "clean"]); + run_command(&["rm", "-f", "oauth.db"]); + run_command(&["rm", "-f", "test.db"]); + } + "server" => { + println!("Starting test server..."); + run_command(&["cargo", "run", "--bin", "sts"]); + } + "migrate" => { + println!("Running migrations for tests..."); + run_command(&["cargo", "run", "--bin", "migrate", "up"]); + } + "reset" => { + println!("Resetting test environment..."); + run_command(&["rm", "-f", "oauth.db"]); + run_command(&["cargo", "run", "--bin", "migrate", "up"]); + } + _ => { + eprintln!("Error: Unknown command 
'{}'", args[1]); + print_usage(); + exit(1); + } + } +} + +fn run_command(cmd: &[&str]) { + let mut command = Command::new(cmd[0]); + if cmd.len() > 1 { + command.args(&cmd[1..]); + } + + let status = command.status().unwrap_or_else(|err| { + eprintln!("Failed to execute command '{:?}': {}", cmd, err); + exit(1); + }); + + if !status.success() { + eprintln!( + "Command '{:?}' failed with exit code: {:?}", + cmd, + status.code() + ); + exit(1); + } +} + +fn run_command_allow_failure(cmd: &[&str]) { + let mut command = Command::new(cmd[0]); + if cmd.len() > 1 { + command.args(&cmd[1..]); + } + + let status = command.status().unwrap_or_else(|err| { + eprintln!("Failed to execute command '{:?}': {}", cmd, err); + exit(1); + }); + + if !status.success() { + eprintln!( + "Command '{:?}' completed with exit code: {:?}", + cmd, + status.code() + ); + // Don't exit, just report the failure + } +} + +fn print_usage() { + println!("OAuth2 STS Test Runner"); + println!("====================="); + println!(); + println!("Usage:"); + println!(" cargo run --bin test <command>"); + println!(); + println!("Commands:"); + println!(" unit Run Rust unit tests only"); + println!(" integration Run Ruby integration tests only"); + println!(" all Run all tests (unit + integration)"); + println!(" watch Run tests in watch mode"); + println!(" watch integration Run integration tests in watch mode"); + println!(" coverage Run tests with coverage report"); + println!(" check Run cargo check"); + println!(" lint Run clippy linting"); + println!(" fmt Run code formatting"); + println!(" clean Clean test artifacts and databases"); + println!(" server Start test server"); + println!(" migrate Run database migrations"); + println!(" reset Reset test environment (clean DB + migrate)"); + println!(); + println!("Examples:"); + println!(" cargo run --bin test unit"); + println!(" cargo run --bin test all"); + println!(" cargo run --bin test watch"); + println!(" cargo run --bin test coverage"); + 
println!(" cargo run --bin test reset"); +} diff --git a/src/database.rs b/src/database.rs index a91579b..178eee3 100644 --- a/src/database.rs +++ b/src/database.rs @@ -116,7 +116,7 @@ impl Database { fn run_migrations(&self) -> Result<()> { // Use the migration system instead of duplicated schema - let migration_runner = crate::migrations::MigrationRunner::new(&self.conn); + let migration_runner = crate::migrations::MigrationRunner::new(&self.conn)?; migration_runner.run_migrations()?; Ok(()) } diff --git a/src/domain/conversions.rs b/src/domain/conversions.rs index 53e6062..13a1b9b 100644 --- a/src/domain/conversions.rs +++ b/src/domain/conversions.rs @@ -1,4 +1,4 @@ -use crate::database::{DbAccessToken, DbAuthCode, DbAuditLog, DbOAuthClient}; +use crate::database::{DbAccessToken, DbAuditLog, DbAuthCode, DbOAuthClient}; use crate::domain::models::*; use anyhow::Result; @@ -18,9 +18,21 @@ pub trait ToDb<T> { impl FromDb<DbOAuthClient> for OAuthClient { fn from_db(db_client: DbOAuthClient) -> Result<Self> { let redirect_uris: Vec<String> = serde_json::from_str(&db_client.redirect_uris)?; - let scopes: Vec<String> = db_client.scopes.split_whitespace().map(|s| s.to_string()).collect(); - let grant_types: Vec<String> = db_client.grant_types.split_whitespace().map(|s| s.to_string()).collect(); - let response_types: Vec<String> = db_client.response_types.split_whitespace().map(|s| s.to_string()).collect(); + let scopes: Vec<String> = db_client + .scopes + .split_whitespace() + .map(|s| s.to_string()) + .collect(); + let grant_types: Vec<String> = db_client + .grant_types + .split_whitespace() + .map(|s| s.to_string()) + .collect(); + let response_types: Vec<String> = db_client + .response_types + .split_whitespace() + .map(|s| s.to_string()) + .collect(); Ok(OAuthClient { client_id: db_client.client_id, @@ -57,8 +69,13 @@ impl ToDb<DbOAuthClient> for OAuthClient { // Authorization Code conversions impl FromDb<DbAuthCode> for AuthorizationCode { fn from_db(db_code: 
DbAuthCode) -> Result<Self> { - let scopes = db_code.scope - .map(|s| s.split_whitespace().map(|scope| scope.to_string()).collect()) + let scopes = db_code + .scope + .map(|s| { + s.split_whitespace() + .map(|scope| scope.to_string()) + .collect() + }) .unwrap_or_default(); Ok(AuthorizationCode { @@ -103,8 +120,13 @@ impl ToDb<DbAuthCode> for AuthorizationCode { // Access Token conversions impl FromDb<DbAccessToken> for AccessToken { fn from_db(db_token: DbAccessToken) -> Result<Self> { - let scopes = db_token.scope - .map(|s| s.split_whitespace().map(|scope| scope.to_string()).collect()) + let scopes = db_token + .scope + .map(|s| { + s.split_whitespace() + .map(|scope| scope.to_string()) + .collect() + }) .unwrap_or_default(); Ok(AccessToken { @@ -171,4 +193,4 @@ impl ToDb<DbAuditLog> for AuditEvent { success: self.success, }) } -}
\ No newline at end of file +} diff --git a/src/domain/dto.rs b/src/domain/dto.rs index 336db61..1c342bc 100644 --- a/src/domain/dto.rs +++ b/src/domain/dto.rs @@ -131,4 +131,4 @@ impl From<crate::domain::OAuthError> for ErrorResponseDto { error_uri: error.uri, } } -}
\ No newline at end of file +} diff --git a/src/domain/mappers.rs b/src/domain/mappers.rs index 6efe276..405b08b 100644 --- a/src/domain/mappers.rs +++ b/src/domain/mappers.rs @@ -1,4 +1,4 @@ -use crate::database::{DbAccessToken, DbAuthCode, DbAuditLog, DbOAuthClient}; +use crate::database::{DbAccessToken, DbAuditLog, DbAuthCode, DbOAuthClient}; use crate::domain::models::*; use anyhow::Result; @@ -14,9 +14,21 @@ pub struct OAuthClientMapper; impl DataMapper<OAuthClient, DbOAuthClient> for OAuthClientMapper { fn to_domain(&self, db_client: DbOAuthClient) -> Result<OAuthClient> { let redirect_uris: Vec<String> = serde_json::from_str(&db_client.redirect_uris)?; - let scopes: Vec<String> = db_client.scopes.split_whitespace().map(|s| s.to_string()).collect(); - let grant_types: Vec<String> = db_client.grant_types.split_whitespace().map(|s| s.to_string()).collect(); - let response_types: Vec<String> = db_client.response_types.split_whitespace().map(|s| s.to_string()).collect(); + let scopes: Vec<String> = db_client + .scopes + .split_whitespace() + .map(|s| s.to_string()) + .collect(); + let grant_types: Vec<String> = db_client + .grant_types + .split_whitespace() + .map(|s| s.to_string()) + .collect(); + let response_types: Vec<String> = db_client + .response_types + .split_whitespace() + .map(|s| s.to_string()) + .collect(); Ok(OAuthClient { client_id: db_client.client_id, @@ -53,8 +65,13 @@ pub struct AuthCodeMapper; impl DataMapper<AuthorizationCode, DbAuthCode> for AuthCodeMapper { fn to_domain(&self, db_code: DbAuthCode) -> Result<AuthorizationCode> { - let scopes = db_code.scope - .map(|s| s.split_whitespace().map(|scope| scope.to_string()).collect()) + let scopes = db_code + .scope + .map(|s| { + s.split_whitespace() + .map(|scope| scope.to_string()) + .collect() + }) .unwrap_or_default(); Ok(AuthorizationCode { @@ -99,8 +116,13 @@ pub struct AccessTokenMapper; impl DataMapper<AccessToken, DbAccessToken> for AccessTokenMapper { fn to_domain(&self, db_token: 
DbAccessToken) -> Result<AccessToken> { - let scopes = db_token.scope - .map(|s| s.split_whitespace().map(|scope| scope.to_string()).collect()) + let scopes = db_token + .scope + .map(|s| { + s.split_whitespace() + .map(|scope| scope.to_string()) + .collect() + }) .unwrap_or_default(); Ok(AccessToken { @@ -206,4 +228,4 @@ impl Default for MapperRegistry { fn default() -> Self { Self::new() } -}
\ No newline at end of file +} diff --git a/src/domain/mod.rs b/src/domain/mod.rs index 9a8bfca..7ba3b00 100644 --- a/src/domain/mod.rs +++ b/src/domain/mod.rs @@ -16,4 +16,4 @@ pub use queries::*; pub use repositories::*; pub use services::*; pub use specifications::*; -pub use unit_of_work::*;
\ No newline at end of file +pub use unit_of_work::*; diff --git a/src/domain/models.rs b/src/domain/models.rs index 85b554f..26e6df3 100644 --- a/src/domain/models.rs +++ b/src/domain/models.rs @@ -108,9 +108,9 @@ pub struct AuthorizationRequest { #[derive(Debug, Clone, PartialEq)] pub struct TokenRequest { pub grant_type: String, - pub code: Option<String>, // For authorization_code grant - pub refresh_token: Option<String>, // For refresh_token grant - pub redirect_uri: Option<String>, // For authorization_code grant + pub code: Option<String>, // For authorization_code grant + pub refresh_token: Option<String>, // For refresh_token grant + pub redirect_uri: Option<String>, // For authorization_code grant pub client_id: String, pub client_secret: Option<String>, // PKCE @@ -220,4 +220,4 @@ impl Scope { description: Some("Access to user email address".to_string()), } } -}
\ No newline at end of file +} diff --git a/src/domain/queries.rs b/src/domain/queries.rs index d4eb19e..88fb480 100644 --- a/src/domain/queries.rs +++ b/src/domain/queries.rs @@ -277,31 +277,24 @@ pub struct CommonQueries; impl CommonQueries { /// Get recent failed login attempts (security monitoring) pub fn recent_failed_logins() -> FailedAuthQuery { - FailedAuthQuery::new() - .last_24_hours() - .min_attempts(3) + FailedAuthQuery::new().last_24_hours().min_attempts(3) } /// Get audit trail for a specific client pub fn client_audit_trail(client_id: &str) -> AuditEventsQuery { - AuditEventsQuery::new() - .for_client(client_id) - .limit(1000) + AuditEventsQuery::new().for_client(client_id).limit(1000) } /// Get token usage statistics for the last 30 days pub fn monthly_token_usage() -> TokenUsageQuery { let now = Utc::now(); let thirty_days_ago = now - chrono::Duration::days(30); - - TokenUsageQuery::new(TokenUsageGroupBy::Day) - .date_range(thirty_days_ago, now) + + TokenUsageQuery::new(TokenUsageGroupBy::Day).date_range(thirty_days_ago, now) } /// Get all active clients with OpenID scope pub fn openid_clients() -> OAuthClientsQuery { - OAuthClientsQuery::new() - .active_only() - .with_scope("openid") + OAuthClientsQuery::new().active_only().with_scope("openid") } -}
\ No newline at end of file +} diff --git a/src/domain/repositories.rs b/src/domain/repositories.rs index 1aa4f33..3622373 100644 --- a/src/domain/repositories.rs +++ b/src/domain/repositories.rs @@ -36,6 +36,11 @@ pub trait DomainAuditRepository: Send + Sync { /// Domain-focused repository for rate limiting pub trait DomainRateRepository: Send + Sync { fn check_rate_limit(&self, limit: &RateLimit) -> Result<u32>; // Returns current count - fn increment_rate_limit(&self, identifier: &str, endpoint: &str, window_minutes: u32) -> Result<u32>; + fn increment_rate_limit( + &self, + identifier: &str, + endpoint: &str, + window_minutes: u32, + ) -> Result<u32>; fn cleanup_old_rate_limits(&self) -> Result<()>; -}
\ No newline at end of file +} diff --git a/src/domain/services.rs b/src/domain/services.rs index 2c23cdc..0e22ddb 100644 --- a/src/domain/services.rs +++ b/src/domain/services.rs @@ -3,10 +3,22 @@ use anyhow::Result; /// Domain service for OAuth2 authorization flow pub trait AuthorizationService: Send + Sync { - fn authorize(&self, request: &AuthorizationRequest, user: &User) -> Result<AuthorizationResult, OAuthError>; + fn authorize( + &self, + request: &AuthorizationRequest, + user: &User, + ) -> Result<AuthorizationResult, OAuthError>; fn validate_client(&self, client_id: &str) -> Result<OAuthClient, OAuthError>; - fn validate_redirect_uri(&self, client: &OAuthClient, redirect_uri: &str) -> Result<(), OAuthError>; - fn validate_scopes(&self, client: &OAuthClient, requested_scopes: &[String]) -> Result<Vec<String>, OAuthError>; + fn validate_redirect_uri( + &self, + client: &OAuthClient, + redirect_uri: &str, + ) -> Result<(), OAuthError>; + fn validate_scopes( + &self, + client: &OAuthClient, + requested_scopes: &[String], + ) -> Result<Vec<String>, OAuthError>; } /// Domain service for OAuth2 token operations @@ -23,28 +35,59 @@ pub trait ClientService: Send + Sync { fn get_client(&self, client_id: &str) -> Result<Option<OAuthClient>>; fn update_client(&self, client: &OAuthClient) -> Result<()>; fn delete_client(&self, client_id: &str) -> Result<()>; - fn authenticate_client(&self, client_id: &str, client_secret: &str) -> Result<OAuthClient, OAuthError>; + fn authenticate_client( + &self, + client_id: &str, + client_secret: &str, + ) -> Result<OAuthClient, OAuthError>; } /// Domain service for user management pub trait UserService: Send + Sync { fn get_user(&self, user_id: &str) -> Result<Option<User>>; fn authenticate_user(&self, username: &str, password: &str) -> Result<User, OAuthError>; - fn is_user_authorized(&self, user: &User, client: &OAuthClient, scopes: &[String]) -> Result<bool>; + fn is_user_authorized( + &self, + user: &User, + client: 
&OAuthClient, + scopes: &[String], + ) -> Result<bool>; } /// Domain service for audit logging pub trait AuditService: Send + Sync { - fn log_authorization_attempt(&self, request: &AuthorizationRequest, user: Option<&User>, success: bool, ip_address: Option<&str>) -> Result<()>; - fn log_token_request(&self, request: &TokenRequest, success: bool, ip_address: Option<&str>) -> Result<()>; - fn log_token_introspection(&self, token_hash: &str, client_id: &str, success: bool) -> Result<()>; + fn log_authorization_attempt( + &self, + request: &AuthorizationRequest, + user: Option<&User>, + success: bool, + ip_address: Option<&str>, + ) -> Result<()>; + fn log_token_request( + &self, + request: &TokenRequest, + success: bool, + ip_address: Option<&str>, + ) -> Result<()>; + fn log_token_introspection( + &self, + token_hash: &str, + client_id: &str, + success: bool, + ) -> Result<()>; fn log_token_revocation(&self, token_hash: &str, client_id: &str, success: bool) -> Result<()>; } /// Domain service for rate limiting pub trait RateLimitService: Send + Sync { fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<(), OAuthError>; - fn is_rate_limited(&self, identifier: &str, endpoint: &str, max_requests: u32, window_minutes: u32) -> Result<bool>; + fn is_rate_limited( + &self, + identifier: &str, + endpoint: &str, + max_requests: u32, + window_minutes: u32, + ) -> Result<bool>; } /// Domain service for PKCE operations @@ -57,7 +100,12 @@ pub trait PkceService: Send + Sync { /// Domain service for JWT operations pub trait JwtService: Send + Sync { fn generate_access_token(&self, claims: &TokenClaims) -> Result<String>; - fn generate_refresh_token(&self, client_id: &str, user_id: &str, scopes: &[String]) -> Result<String>; + fn generate_refresh_token( + &self, + client_id: &str, + user_id: &str, + scopes: &[String], + ) -> Result<String>; fn validate_token(&self, token: &str) -> Result<TokenClaims>; fn get_jwks(&self) -> Result<String>; // JSON Web Key Set -}
\ No newline at end of file +} diff --git a/src/domain/specifications.rs b/src/domain/specifications.rs index 3237d1b..76aafcc 100644 --- a/src/domain/specifications.rs +++ b/src/domain/specifications.rs @@ -22,7 +22,7 @@ impl<T> Specification<T> for AndSpecification<T> { fn is_satisfied_by(&self, candidate: &T) -> bool { self.left.is_satisfied_by(candidate) && self.right.is_satisfied_by(candidate) } - + fn reason_for_failure(&self, candidate: &T) -> Option<String> { if !self.left.is_satisfied_by(candidate) { self.left.reason_for_failure(candidate) @@ -40,7 +40,7 @@ impl Specification<OAuthClient> for ActiveClientSpecification { fn is_satisfied_by(&self, client: &OAuthClient) -> bool { client.is_active } - + fn reason_for_failure(&self, _client: &OAuthClient) -> Option<String> { Some("Client is not active".to_string()) } @@ -60,7 +60,7 @@ impl Specification<OAuthClient> for ValidRedirectUriSpecification { fn is_satisfied_by(&self, client: &OAuthClient) -> bool { client.redirect_uris.contains(&self.redirect_uri) } - + fn reason_for_failure(&self, _client: &OAuthClient) -> Option<String> { Some(format!("Invalid redirect_uri: {}", self.redirect_uri)) } @@ -78,15 +78,19 @@ impl SupportedScopesSpecification { impl Specification<OAuthClient> for SupportedScopesSpecification { fn is_satisfied_by(&self, client: &OAuthClient) -> bool { - self.requested_scopes.iter().all(|scope| client.scopes.contains(scope)) + self.requested_scopes + .iter() + .all(|scope| client.scopes.contains(scope)) } - + fn reason_for_failure(&self, client: &OAuthClient) -> Option<String> { - let unsupported: Vec<_> = self.requested_scopes.iter() + let unsupported: Vec<_> = self + .requested_scopes + .iter() .filter(|scope| !client.scopes.contains(scope)) .cloned() .collect(); - + if !unsupported.is_empty() { Some(format!("Unsupported scopes: {}", unsupported.join(", "))) } else { @@ -101,7 +105,7 @@ impl Specification<AuthorizationCode> for UnusedAuthCodeSpecification { fn is_satisfied_by(&self, code: 
&AuthorizationCode) -> bool { !code.is_used } - + fn reason_for_failure(&self, _code: &AuthorizationCode) -> Option<String> { Some("Authorization code has already been used".to_string()) } @@ -112,7 +116,7 @@ impl Specification<AuthorizationCode> for ValidAuthCodeSpecification { fn is_satisfied_by(&self, code: &AuthorizationCode) -> bool { chrono::Utc::now() < code.expires_at } - + fn reason_for_failure(&self, _code: &AuthorizationCode) -> Option<String> { Some("Authorization code has expired".to_string()) } @@ -132,7 +136,7 @@ impl Specification<AuthorizationCode> for MatchingClientSpecification { fn is_satisfied_by(&self, code: &AuthorizationCode) -> bool { code.client_id == self.client_id } - + fn reason_for_failure(&self, _code: &AuthorizationCode) -> Option<String> { Some("Client ID mismatch".to_string()) } @@ -153,14 +157,18 @@ impl Specification<AuthorizationCode> for ValidPkceSpecification { fn is_satisfied_by(&self, code: &AuthorizationCode) -> bool { if let Some(challenge) = &code.code_challenge { let method = code.code_challenge_method.as_deref().unwrap_or("plain"); - crate::oauth::pkce::verify_code_challenge(&self.code_verifier, challenge, - &crate::oauth::pkce::CodeChallengeMethod::from_str(method).unwrap_or(crate::oauth::pkce::CodeChallengeMethod::Plain) - ).is_ok() + crate::oauth::pkce::verify_code_challenge( + &self.code_verifier, + challenge, + &crate::oauth::pkce::CodeChallengeMethod::from_str(method) + .unwrap_or(crate::oauth::pkce::CodeChallengeMethod::Plain), + ) + .is_ok() } else { true // No PKCE required } } - + fn reason_for_failure(&self, _code: &AuthorizationCode) -> Option<String> { Some("PKCE verification failed".to_string()) } @@ -172,7 +180,7 @@ impl Specification<AccessToken> for ValidTokenSpecification { fn is_satisfied_by(&self, token: &AccessToken) -> bool { !token.is_revoked && chrono::Utc::now() < token.expires_at } - + fn reason_for_failure(&self, token: &AccessToken) -> Option<String> { if token.is_revoked { Some("Token has 
been revoked".to_string()) @@ -191,4 +199,4 @@ pub trait SpecificationExt<T>: Specification<T> + Sized + 'static { } } -impl<T, S: Specification<T> + 'static> SpecificationExt<T> for S {}
\ No newline at end of file +impl<T, S: Specification<T> + 'static> SpecificationExt<T> for S {} diff --git a/src/domain/unit_of_work.rs b/src/domain/unit_of_work.rs index db8294a..7d6a0e3 100644 --- a/src/domain/unit_of_work.rs +++ b/src/domain/unit_of_work.rs @@ -11,10 +11,10 @@ pub trait UnitOfWork: Send + Sync { pub trait Transaction: Send + Sync { /// Commit all changes in this transaction fn commit(self: Box<Self>) -> Result<()>; - + /// Rollback all changes in this transaction fn rollback(self: Box<Self>) -> Result<()>; - + /// Get repositories within this transaction context fn client_repository(&self) -> Arc<dyn crate::domain::DomainClientRepository>; fn auth_code_repository(&self) -> Arc<dyn crate::domain::DomainAuthCodeRepository>; @@ -31,7 +31,7 @@ impl OAuthUnitOfWork { pub fn new(uow: Arc<dyn UnitOfWork>) -> Self { Self { uow } } - + /// Execute OAuth2 authorization code exchange atomically pub fn exchange_authorization_code<F>(&self, operation: F) -> Result<()> where @@ -46,7 +46,7 @@ impl OAuthUnitOfWork { } } } - + /// Execute token refresh atomically pub fn refresh_tokens<F>(&self, operation: F) -> Result<()> where @@ -61,4 +61,4 @@ impl OAuthUnitOfWork { } } } -}
\ No newline at end of file +} @@ -5,6 +5,7 @@ pub mod database; pub mod domain; pub mod http; pub mod keys; +pub mod migration_discovery; pub mod migrations; pub mod oauth; pub mod repositories; diff --git a/src/migration_discovery.rs b/src/migration_discovery.rs new file mode 100644 index 0000000..fabc660 --- /dev/null +++ b/src/migration_discovery.rs @@ -0,0 +1,296 @@ +use anyhow::{Result, anyhow}; +use chrono::Utc; +use std::collections::BTreeMap; +use std::path::Path; + +#[derive(Debug, Clone)] +pub struct Migration { + pub version: i64, + pub name: String, + pub sql: String, +} + +/// Migration discovery that reads migration files from the filesystem at runtime +pub struct RuntimeMigrationDiscovery { + migrations_dir: std::path::PathBuf, + migrations: BTreeMap<i64, Migration>, +} + +impl RuntimeMigrationDiscovery { + pub fn new<P: AsRef<Path>>(migrations_dir: P) -> Result<Self> { + let migrations_dir = migrations_dir.as_ref().to_path_buf(); + let mut discovery = Self { + migrations_dir, + migrations: BTreeMap::new(), + }; + discovery.discover_migrations()?; + Ok(discovery) + } + + fn discover_migrations(&mut self) -> Result<()> { + if !self.migrations_dir.exists() { + return Err(anyhow!( + "Migrations directory does not exist: {:?}", + self.migrations_dir + )); + } + + let entries = std::fs::read_dir(&self.migrations_dir)?; + + for entry in entries { + let entry = entry?; + let path = entry.path(); + + if path.is_file() && path.extension().map_or(false, |ext| ext == "sql") { + let migration = self.parse_migration_file(&path)?; + self.migrations.insert(migration.version, migration); + } + } + + Ok(()) + } + + fn parse_migration_file(&self, path: &Path) -> Result<Migration> { + let file_name = path + .file_name() + .and_then(|n| n.to_str()) + .ok_or_else(|| anyhow!("Invalid migration file name: {:?}", path))?; + + let (version, name) = parse_migration_filename(file_name)?; + let sql = std::fs::read_to_string(path)?; + + Ok(Migration { version, name, sql }) + } 
+ + pub fn get_migrations(&self) -> Vec<&Migration> { + self.migrations.values().collect() + } + + pub fn get_migration(&self, version: i64) -> Option<&Migration> { + self.migrations.get(&version) + } + + pub fn get_pending_migrations(&self, current_version: i64) -> Vec<&Migration> { + self.migrations + .values() + .filter(|m| m.version > current_version) + .collect() + } + + pub fn refresh(&mut self) -> Result<()> { + self.migrations.clear(); + self.discover_migrations() + } +} + +/// Parse migration filename to extract timestamp and name +/// Expected format: "20231201123456_initial_schema.sql" -> (20231201123456, "initial_schema") +fn parse_migration_filename(filename: &str) -> Result<(i64, String)> { + if !filename.ends_with(".sql") { + return Err(anyhow!( + "Migration file must have .sql extension: {}", + filename + )); + } + + let name_without_ext = &filename[..filename.len() - 4]; + + // Find first underscore + let underscore_pos = name_without_ext.find('_').ok_or_else(|| { + anyhow!( + "Migration filename must be in format 'YYYYMMDDHHMMSS_name.sql': {}", + filename + ) + })?; + + let timestamp_str = &name_without_ext[..underscore_pos]; + let name = &name_without_ext[underscore_pos + 1..]; + + // Validate timestamp format (should be 14 digits for YYYYMMDDHHMMSS) + if timestamp_str.len() != 14 { + return Err(anyhow!( + "Timestamp must be 14 digits (YYYYMMDDHHMMSS) in migration filename: {}", + filename + )); + } + + let timestamp = timestamp_str + .parse::<i64>() + .map_err(|_| anyhow!("Invalid timestamp in migration filename: {}", filename))?; + + // Basic validation of timestamp format + let year = timestamp / 10000000000; + let month = (timestamp / 100000000) % 100; + let day = (timestamp / 1000000) % 100; + + if year < 2000 || year > 3000 { + return Err(anyhow!("Invalid year in timestamp: {}", timestamp_str)); + } + if month < 1 || month > 12 { + return Err(anyhow!("Invalid month in timestamp: {}", timestamp_str)); + } + if day < 1 || day > 31 { + return 
Err(anyhow!("Invalid day in timestamp: {}", timestamp_str)); + } + + Ok((timestamp, name.to_string())) +} + +/// Generate a timestamp prefix for a new migration file +/// Returns a string in format YYYYMMDDHHMMSS +pub fn generate_migration_timestamp() -> String { + let now = Utc::now(); + now.format("%Y%m%d%H%M%S").to_string() +} + +/// Generate a full migration filename with timestamp prefix +/// Example: generate_migration_filename("add_users_table") -> "20231201123456_add_users_table.sql" +pub fn generate_migration_filename(name: &str) -> String { + format!("{}_{}.sql", generate_migration_timestamp(), name) +} + +/// Trait for migration discovery to allow different implementations +pub trait MigrationDiscovery { + fn get_migrations(&self) -> Vec<&Migration>; + fn get_pending_migrations(&self, current_version: i64) -> Vec<&Migration>; +} + +impl MigrationDiscovery for RuntimeMigrationDiscovery { + fn get_migrations(&self) -> Vec<&Migration> { + self.get_migrations() + } + + fn get_pending_migrations(&self, current_version: i64) -> Vec<&Migration> { + self.get_pending_migrations(current_version) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs; + use tempfile::TempDir; + + #[test] + fn test_parse_migration_filename() { + let cases = vec![ + ( + "20231201123456_initial_schema.sql", + Ok((20231201123456, "initial_schema".to_string())), + ), + ( + "20231202140000_add_users_table.sql", + Ok((20231202140000, "add_users_table".to_string())), + ), + ( + "20240315091530_complex_migration_name.sql", + Ok((20240315091530, "complex_migration_name".to_string())), + ), + ("invalid.sql", Err(())), // No underscore + ("abc_invalid.sql", Err(())), // Non-numeric timestamp + ("20231201123456_no_extension", Err(())), // No .sql extension + ("123_too_short.sql", Err(())), // Timestamp too short + ("202312011234567_too_long.sql", Err(())), // Timestamp too long + ("20231300123456_invalid_month.sql", Err(())), // Invalid month + ("20231232123456_invalid_day.sql", 
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    #[test]
    fn test_parse_migration_filename() {
        // (input, expected) pairs. `Err(())` marks filenames that must be
        // rejected; the specific error message is not part of the contract.
        let cases: Vec<(&str, Result<(i64, String), ()>)> = vec![
            (
                "20231201123456_initial_schema.sql",
                Ok((20231201123456, "initial_schema".to_string())),
            ),
            (
                "20231202140000_add_users_table.sql",
                Ok((20231202140000, "add_users_table".to_string())),
            ),
            (
                "20240315091530_complex_migration_name.sql",
                Ok((20240315091530, "complex_migration_name".to_string())),
            ),
            ("invalid.sql", Err(())),                      // No underscore
            ("abc_invalid.sql", Err(())),                  // Non-numeric timestamp
            ("20231201123456_no_extension", Err(())),      // No .sql extension
            ("123_too_short.sql", Err(())),                // Timestamp too short
            ("202312011234567_too_long.sql", Err(())),     // Timestamp too long
            ("20231300123456_invalid_month.sql", Err(())), // Invalid month
            ("20231232123456_invalid_day.sql", Err(())),   // Invalid day
            ("19991201123456_invalid_year.sql", Err(())),  // Year too old
        ];

        for (input, expected) in cases {
            // Collapse the error payload so success/failure compares directly,
            // instead of re-implementing assert_eq! with a manual match.
            let actual = parse_migration_filename(input).map_err(|_| ());
            assert_eq!(actual, expected, "unexpected result for {}", input);
        }
    }

    #[test]
    fn test_runtime_migration_discovery() {
        // Discovery should pick up every *.sql file in the directory.
        let temp_dir = TempDir::new().unwrap();
        let migrations_dir = temp_dir.path();

        fs::write(
            migrations_dir.join("20231201120000_initial_schema.sql"),
            "CREATE TABLE users (id INTEGER PRIMARY KEY);",
        )
        .unwrap();
        fs::write(
            migrations_dir.join("20231201130000_add_posts.sql"),
            "CREATE TABLE posts (id INTEGER PRIMARY KEY, user_id INTEGER);",
        )
        .unwrap();

        let discovery = RuntimeMigrationDiscovery::new(migrations_dir).unwrap();
        let migrations = discovery.get_migrations();

        assert_eq!(migrations.len(), 2);
        // Migrations must come back sorted by timestamp.
        assert_eq!(migrations[0].version, 20231201120000);
        assert_eq!(migrations[0].name, "initial_schema");
        assert_eq!(migrations[1].version, 20231201130000);
        assert_eq!(migrations[1].name, "add_posts");
    }

    #[test]
    fn test_pending_migrations() {
        let temp_dir = TempDir::new().unwrap();
        let migrations_dir = temp_dir.path();

        fs::write(
            migrations_dir.join("20231201120000_initial_schema.sql"),
            "CREATE TABLE users (id INTEGER PRIMARY KEY);",
        )
        .unwrap();
        fs::write(
            migrations_dir.join("20231201130000_add_posts.sql"),
            "CREATE TABLE posts (id INTEGER PRIMARY KEY, user_id INTEGER);",
        )
        .unwrap();

        let discovery = RuntimeMigrationDiscovery::new(migrations_dir).unwrap();

        // No pending migrations when already at the latest timestamp.
        assert!(discovery.get_pending_migrations(20231201130000).is_empty());

        // From version 0 every migration is pending, in ascending order.
        let pending = discovery.get_pending_migrations(0);
        assert_eq!(pending.len(), 2);
        assert_eq!(pending[0].version, 20231201120000);
        assert_eq!(pending[1].version, 20231201130000);

        // After the first migration only the second remains pending.
        let pending = discovery.get_pending_migrations(20231201120000);
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].version, 20231201130000);
    }
}
discovery }) + } + + pub fn new_with_path<P: AsRef<Path>>(conn: &'a Connection, migrations_dir: P) -> Result<Self> { + let discovery = RuntimeMigrationDiscovery::new(migrations_dir)?; + Ok(Self { conn, discovery }) } pub fn run_migrations(&self) -> Result<()> { // Create migrations table if it doesn't exist self.conn.execute( "CREATE TABLE IF NOT EXISTS schema_migrations ( - version INTEGER PRIMARY KEY, + version BIGINT PRIMARY KEY, name TEXT NOT NULL, applied_at TEXT NOT NULL )", @@ -46,22 +37,23 @@ impl<'a> MigrationRunner<'a> { println!("Current database version: {}", current_version); - // Run pending migrations - for migration in MIGRATIONS { - if migration.version > current_version { - println!( - "Running migration {}: {}", - migration.version, migration.name - ); - self.run_migration(migration)?; - } + // Get pending migrations from discovery system + let pending_migrations = self.discovery.get_pending_migrations(current_version); + + // Run each pending migration + for migration in pending_migrations { + println!( + "Running migration {}: {}", + migration.version, migration.name + ); + self.run_migration(migration)?; } println!("All migrations completed successfully"); Ok(()) } - fn get_current_version(&self) -> Result<i32> { + fn get_current_version(&self) -> Result<i64> { // Check if schema_migrations table exists first let table_exists = self.conn.query_row( "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'", @@ -75,7 +67,7 @@ impl<'a> MigrationRunner<'a> { let version = self.conn.query_row( "SELECT COALESCE(MAX(version), 0) FROM schema_migrations", [], - |row| row.get::<_, i32>(0), + |row| row.get::<_, i64>(0), )?; Ok(version) } @@ -88,14 +80,14 @@ impl<'a> MigrationRunner<'a> { fn run_migration(&self, migration: &Migration) -> Result<()> { // Execute the migration SQL - self.conn.execute_batch(migration.sql)?; + self.conn.execute_batch(&migration.sql)?; // Record the migration as applied self.conn.execute( "INSERT INTO 
schema_migrations (version, name, applied_at) VALUES (?1, ?2, ?3)", [ &migration.version.to_string(), - migration.name, + &migration.name, &chrono::Utc::now().to_rfc3339(), ], )?; @@ -103,7 +95,7 @@ impl<'a> MigrationRunner<'a> { Ok(()) } - pub fn rollback_to_version(&self, target_version: i32) -> Result<()> { + pub fn rollback_to_version(&self, target_version: i64) -> Result<()> { println!("Rolling back to version {}", target_version); // This is a simplified rollback - in practice you'd need down migrations @@ -127,7 +119,7 @@ impl<'a> MigrationRunner<'a> { let migrations = stmt.query_map([], |row| { Ok(( - row.get::<_, i32>(0)?, + row.get::<_, i64>(0)?, row.get::<_, String>(1)?, row.get::<_, String>(2)?, )) @@ -143,13 +135,12 @@ impl<'a> MigrationRunner<'a> { // Show pending migrations let current_version = self.get_current_version()?; - for migration in MIGRATIONS { - if migration.version > current_version { - println!( - "⏳ Migration {}: {} (pending)", - migration.version, migration.name - ); - } + let pending_migrations = self.discovery.get_pending_migrations(current_version); + for migration in pending_migrations { + println!( + "⏳ Migration {}: {} (pending)", + migration.version, migration.name + ); } Ok(()) @@ -162,8 +153,21 @@ mod tests { #[test] fn test_migration_runner() { + use std::fs; + use tempfile::TempDir; + + // Create a temporary directory with a test migration + let temp_dir = TempDir::new().unwrap(); + let migrations_dir = temp_dir.path(); + + fs::write( + migrations_dir.join("20231201120000_test_migration.sql"), + "CREATE TABLE test_table (id INTEGER PRIMARY KEY);", + ) + .unwrap(); + let conn = Connection::open_in_memory().unwrap(); - let runner = MigrationRunner::new(&conn); + let runner = MigrationRunner::new_with_path(&conn, migrations_dir).unwrap(); // Should start with version 0 assert_eq!(runner.get_current_version().unwrap(), 0); @@ -171,10 +175,7 @@ mod tests { // Run migrations runner.run_migrations().unwrap(); - // Should now be at 
latest version - assert_eq!( - runner.get_current_version().unwrap(), - MIGRATIONS.len() as i32 - ); + // Should now be at latest version (timestamp of our migration) + assert_eq!(runner.get_current_version().unwrap(), 20231201120000); } } |
