summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock49
-rw-r--r--Cargo.toml11
-rw-r--r--migrations/20241201000000_initial_schema.sql (renamed from migrations/001_initial_schema.sql)0
-rw-r--r--spec/integration/server_spec.rb4
-rw-r--r--src/bin/generate_migration.rs22
-rw-r--r--src/bin/migrate.rs8
-rw-r--r--src/bin/test.rs154
-rw-r--r--src/database.rs2
-rw-r--r--src/domain/conversions.rs40
-rw-r--r--src/domain/dto.rs2
-rw-r--r--src/domain/mappers.rs40
-rw-r--r--src/domain/mod.rs2
-rw-r--r--src/domain/models.rs8
-rw-r--r--src/domain/queries.rs19
-rw-r--r--src/domain/repositories.rs9
-rw-r--r--src/domain/services.rs70
-rw-r--r--src/domain/specifications.rs40
-rw-r--r--src/domain/unit_of_work.rs10
-rw-r--r--src/lib.rs1
-rw-r--r--src/migration_discovery.rs296
-rw-r--r--src/migrations.rs103
21 files changed, 763 insertions, 127 deletions
diff --git a/Cargo.lock b/Cargo.lock
index f91c48a..68cadb0 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -220,6 +220,16 @@ dependencies = [
]
[[package]]
+name = "errno"
+version = "0.3.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cea14ef9355e3beab063703aa9dab15afd25f0667c341310c1e5274bb1d0da18"
+dependencies = [
+ "libc",
+ "windows-sys 0.59.0",
+]
+
+[[package]]
name = "fallible-iterator"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -232,6 +242,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
[[package]]
+name = "fastrand"
+version = "2.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
+
+[[package]]
name = "form_urlencoded"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -494,6 +510,12 @@ dependencies = [
]
[[package]]
+name = "linux-raw-sys"
+version = "0.9.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12"
+
+[[package]]
name = "litemap"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -843,6 +865,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "989e6739f80c4ad5b13e0fd7fe89531180375b18520cc8c82080e4dc4035b84f"
[[package]]
+name = "rustix"
+version = "1.0.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266"
+dependencies = [
+ "bitflags",
+ "errno",
+ "libc",
+ "linux-raw-sys",
+ "windows-sys 0.59.0",
+]
+
+[[package]]
name = "rustversion"
version = "1.0.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -993,6 +1028,7 @@ dependencies = [
"serde_json",
"sha2",
"subtle",
+ "tempfile",
"tokio",
"url",
"urlencoding",
@@ -1028,6 +1064,19 @@ dependencies = [
]
[[package]]
+name = "tempfile"
+version = "3.20.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1"
+dependencies = [
+ "fastrand",
+ "getrandom 0.3.3",
+ "once_cell",
+ "rustix",
+ "windows-sys 0.59.0",
+]
+
+[[package]]
name = "thiserror"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 5c2f1e7..853f78c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,6 +15,14 @@ path = "src/bin/migrate.rs"
name = "debug"
path = "src/bin/debug.rs"
+[[bin]]
+name = "generate_migration"
+path = "src/bin/generate_migration.rs"
+
+[[bin]]
+name = "test"
+path = "src/bin/test.rs"
+
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
@@ -31,3 +39,6 @@ chrono = { version = "0.4", features = ["serde"] }
tokio = { version = "1.0", features = ["full"] }
anyhow = "1.0"
subtle = "2.4"
+
+[dev-dependencies]
+tempfile = "3.0"
diff --git a/migrations/001_initial_schema.sql b/migrations/20241201000000_initial_schema.sql
index 8796157..8796157 100644
--- a/migrations/001_initial_schema.sql
+++ b/migrations/20241201000000_initial_schema.sql
diff --git a/spec/integration/server_spec.rb b/spec/integration/server_spec.rb
index 229203c..36cb81c 100644
--- a/spec/integration/server_spec.rb
+++ b/spec/integration/server_spec.rb
@@ -62,6 +62,10 @@ RSpec.describe "Server" do
end
end
+ describe "GET /authorize" do
+ pending
+ end
+
# https://datatracker.ietf.org/doc/html/rfc8693#section-2.3
describe "POST /token" do
pending
diff --git a/src/bin/generate_migration.rs b/src/bin/generate_migration.rs
new file mode 100644
index 0000000..c72db66
--- /dev/null
+++ b/src/bin/generate_migration.rs
@@ -0,0 +1,22 @@
+use std::env;
+use sts::migration_discovery::{generate_migration_filename, generate_migration_timestamp};
+
+fn main() {
+ let args: Vec<String> = env::args().collect();
+
+ if args.len() < 2 {
+ eprintln!("Usage: cargo run --bin generate_migration <migration_name>");
+ eprintln!("Example: cargo run --bin generate_migration add_users_table");
+ return;
+ }
+
+ let migration_name = &args[1];
+ let filename = generate_migration_filename(migration_name);
+ let timestamp = generate_migration_timestamp();
+
+ println!("Generated migration filename: {}", filename);
+ println!("Timestamp: {}", timestamp);
+ println!("");
+ println!("To create the migration file:");
+ println!("touch migrations/{}", filename);
+}
diff --git a/src/bin/migrate.rs b/src/bin/migrate.rs
index 9a0bab9..fbf5183 100644
--- a/src/bin/migrate.rs
+++ b/src/bin/migrate.rs
@@ -13,7 +13,7 @@ fn main() -> Result<()> {
let config = Config::from_env();
let conn = Connection::open(&config.database_path)?;
- let runner = MigrationRunner::new(&conn);
+ let runner = MigrationRunner::new(&conn)?;
match args[1].as_str() {
"up" => {
@@ -29,7 +29,7 @@ fn main() -> Result<()> {
eprintln!("Usage: cargo run --bin migrate rollback <version>");
return Ok(());
}
- let version: i32 = args[2]
+ let version: i64 = args[2]
.parse()
.map_err(|_| anyhow::anyhow!("Invalid version number: {}", args[2]))?;
runner.rollback_to_version(version)?;
@@ -50,7 +50,7 @@ fn print_usage() {
println!("Usage:");
println!(" cargo run --bin migrate up # Run pending migrations");
println!(" cargo run --bin migrate status # Show migration status");
- println!(" cargo run --bin migrate rollback <version> # Rollback to version");
+ println!(" cargo run --bin migrate rollback <timestamp> # Rollback to timestamp");
println!("");
println!("Environment Variables:");
println!(" DATABASE_PATH Path to SQLite database (default: oauth.db)");
@@ -58,5 +58,5 @@ fn print_usage() {
println!("Examples:");
println!(" cargo run --bin migrate up");
println!(" cargo run --bin migrate status");
- println!(" cargo run --bin migrate rollback 0");
+ println!(" cargo run --bin migrate rollback 20231201120000");
}
diff --git a/src/bin/test.rs b/src/bin/test.rs
new file mode 100644
index 0000000..d2704c4
--- /dev/null
+++ b/src/bin/test.rs
@@ -0,0 +1,154 @@
+use std::env;
+use std::process::{Command, exit};
+
+fn main() {
+ let args: Vec<String> = env::args().collect();
+
+ if args.len() < 2 {
+ print_usage();
+ return;
+ }
+
+ match args[1].as_str() {
+ "unit" => {
+ println!("Running unit tests...");
+ run_command(&["cargo", "test", "--lib"]);
+ }
+ "integration" => {
+ println!("Running integration tests...");
+ run_command_allow_failure(&["bundle", "exec", "rspec"]);
+ }
+ "all" => {
+ println!("Running all tests...");
+ println!("===================");
+ println!();
+
+ println!("1. Running Rust unit tests...");
+ run_command(&["cargo", "test"]);
+ println!();
+
+ println!("2. Running Ruby integration tests...");
+ run_command_allow_failure(&["bundle", "exec", "rspec"]);
+ }
+ "watch" => {
+ println!("Running tests in watch mode...");
+ if args.len() > 2 && args[2] == "integration" {
+ run_command(&["bundle", "exec", "guard"]);
+ } else {
+ run_command(&["cargo", "watch", "-x", "test"]);
+ }
+ }
+ "coverage" => {
+ println!("Running tests with coverage...");
+ run_command(&["cargo", "tarpaulin", "--out", "Html"]);
+ }
+ "check" => {
+ println!("Running cargo check...");
+ run_command(&["cargo", "check"]);
+ }
+ "lint" => {
+ println!("Running linting...");
+ run_command(&["cargo", "clippy", "--", "-D", "warnings"]);
+ }
+ "fmt" => {
+ println!("Running code formatting...");
+ run_command(&["cargo", "fmt"]);
+ }
+ "clean" => {
+ println!("Cleaning test artifacts...");
+ run_command(&["cargo", "clean"]);
+ run_command(&["rm", "-f", "oauth.db"]);
+ run_command(&["rm", "-f", "test.db"]);
+ }
+ "server" => {
+ println!("Starting test server...");
+ run_command(&["cargo", "run", "--bin", "sts"]);
+ }
+ "migrate" => {
+ println!("Running migrations for tests...");
+ run_command(&["cargo", "run", "--bin", "migrate", "up"]);
+ }
+ "reset" => {
+ println!("Resetting test environment...");
+ run_command(&["rm", "-f", "oauth.db"]);
+ run_command(&["cargo", "run", "--bin", "migrate", "up"]);
+ }
+ _ => {
+ eprintln!("Error: Unknown command '{}'", args[1]);
+ print_usage();
+ exit(1);
+ }
+ }
+}
+
+fn run_command(cmd: &[&str]) {
+ let mut command = Command::new(cmd[0]);
+ if cmd.len() > 1 {
+ command.args(&cmd[1..]);
+ }
+
+ let status = command.status().unwrap_or_else(|err| {
+ eprintln!("Failed to execute command '{:?}': {}", cmd, err);
+ exit(1);
+ });
+
+ if !status.success() {
+ eprintln!(
+ "Command '{:?}' failed with exit code: {:?}",
+ cmd,
+ status.code()
+ );
+ exit(1);
+ }
+}
+
+fn run_command_allow_failure(cmd: &[&str]) {
+ let mut command = Command::new(cmd[0]);
+ if cmd.len() > 1 {
+ command.args(&cmd[1..]);
+ }
+
+ let status = command.status().unwrap_or_else(|err| {
+ eprintln!("Failed to execute command '{:?}': {}", cmd, err);
+ exit(1);
+ });
+
+ if !status.success() {
+ eprintln!(
+ "Command '{:?}' completed with exit code: {:?}",
+ cmd,
+ status.code()
+ );
+ // Don't exit, just report the failure
+ }
+}
+
+fn print_usage() {
+ println!("OAuth2 STS Test Runner");
+ println!("=====================");
+ println!();
+ println!("Usage:");
+ println!(" cargo run --bin test <command>");
+ println!();
+ println!("Commands:");
+ println!(" unit Run Rust unit tests only");
+ println!(" integration Run Ruby integration tests only");
+ println!(" all Run all tests (unit + integration)");
+ println!(" watch Run tests in watch mode");
+ println!(" watch integration Run integration tests in watch mode");
+ println!(" coverage Run tests with coverage report");
+ println!(" check Run cargo check");
+ println!(" lint Run clippy linting");
+ println!(" fmt Run code formatting");
+ println!(" clean Clean test artifacts and databases");
+ println!(" server Start test server");
+ println!(" migrate Run database migrations");
+ println!(" reset Reset test environment (clean DB + migrate)");
+ println!();
+ println!("Examples:");
+ println!(" cargo run --bin test unit");
+ println!(" cargo run --bin test all");
+ println!(" cargo run --bin test watch");
+ println!(" cargo run --bin test coverage");
+ println!(" cargo run --bin test reset");
+}
diff --git a/src/database.rs b/src/database.rs
index a91579b..178eee3 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -116,7 +116,7 @@ impl Database {
fn run_migrations(&self) -> Result<()> {
// Use the migration system instead of duplicated schema
- let migration_runner = crate::migrations::MigrationRunner::new(&self.conn);
+ let migration_runner = crate::migrations::MigrationRunner::new(&self.conn)?;
migration_runner.run_migrations()?;
Ok(())
}
diff --git a/src/domain/conversions.rs b/src/domain/conversions.rs
index 53e6062..13a1b9b 100644
--- a/src/domain/conversions.rs
+++ b/src/domain/conversions.rs
@@ -1,4 +1,4 @@
-use crate::database::{DbAccessToken, DbAuthCode, DbAuditLog, DbOAuthClient};
+use crate::database::{DbAccessToken, DbAuditLog, DbAuthCode, DbOAuthClient};
use crate::domain::models::*;
use anyhow::Result;
@@ -18,9 +18,21 @@ pub trait ToDb<T> {
impl FromDb<DbOAuthClient> for OAuthClient {
fn from_db(db_client: DbOAuthClient) -> Result<Self> {
let redirect_uris: Vec<String> = serde_json::from_str(&db_client.redirect_uris)?;
- let scopes: Vec<String> = db_client.scopes.split_whitespace().map(|s| s.to_string()).collect();
- let grant_types: Vec<String> = db_client.grant_types.split_whitespace().map(|s| s.to_string()).collect();
- let response_types: Vec<String> = db_client.response_types.split_whitespace().map(|s| s.to_string()).collect();
+ let scopes: Vec<String> = db_client
+ .scopes
+ .split_whitespace()
+ .map(|s| s.to_string())
+ .collect();
+ let grant_types: Vec<String> = db_client
+ .grant_types
+ .split_whitespace()
+ .map(|s| s.to_string())
+ .collect();
+ let response_types: Vec<String> = db_client
+ .response_types
+ .split_whitespace()
+ .map(|s| s.to_string())
+ .collect();
Ok(OAuthClient {
client_id: db_client.client_id,
@@ -57,8 +69,13 @@ impl ToDb<DbOAuthClient> for OAuthClient {
// Authorization Code conversions
impl FromDb<DbAuthCode> for AuthorizationCode {
fn from_db(db_code: DbAuthCode) -> Result<Self> {
- let scopes = db_code.scope
- .map(|s| s.split_whitespace().map(|scope| scope.to_string()).collect())
+ let scopes = db_code
+ .scope
+ .map(|s| {
+ s.split_whitespace()
+ .map(|scope| scope.to_string())
+ .collect()
+ })
.unwrap_or_default();
Ok(AuthorizationCode {
@@ -103,8 +120,13 @@ impl ToDb<DbAuthCode> for AuthorizationCode {
// Access Token conversions
impl FromDb<DbAccessToken> for AccessToken {
fn from_db(db_token: DbAccessToken) -> Result<Self> {
- let scopes = db_token.scope
- .map(|s| s.split_whitespace().map(|scope| scope.to_string()).collect())
+ let scopes = db_token
+ .scope
+ .map(|s| {
+ s.split_whitespace()
+ .map(|scope| scope.to_string())
+ .collect()
+ })
.unwrap_or_default();
Ok(AccessToken {
@@ -171,4 +193,4 @@ impl ToDb<DbAuditLog> for AuditEvent {
success: self.success,
})
}
-} \ No newline at end of file
+}
diff --git a/src/domain/dto.rs b/src/domain/dto.rs
index 336db61..1c342bc 100644
--- a/src/domain/dto.rs
+++ b/src/domain/dto.rs
@@ -131,4 +131,4 @@ impl From<crate::domain::OAuthError> for ErrorResponseDto {
error_uri: error.uri,
}
}
-} \ No newline at end of file
+}
diff --git a/src/domain/mappers.rs b/src/domain/mappers.rs
index 6efe276..405b08b 100644
--- a/src/domain/mappers.rs
+++ b/src/domain/mappers.rs
@@ -1,4 +1,4 @@
-use crate::database::{DbAccessToken, DbAuthCode, DbAuditLog, DbOAuthClient};
+use crate::database::{DbAccessToken, DbAuditLog, DbAuthCode, DbOAuthClient};
use crate::domain::models::*;
use anyhow::Result;
@@ -14,9 +14,21 @@ pub struct OAuthClientMapper;
impl DataMapper<OAuthClient, DbOAuthClient> for OAuthClientMapper {
fn to_domain(&self, db_client: DbOAuthClient) -> Result<OAuthClient> {
let redirect_uris: Vec<String> = serde_json::from_str(&db_client.redirect_uris)?;
- let scopes: Vec<String> = db_client.scopes.split_whitespace().map(|s| s.to_string()).collect();
- let grant_types: Vec<String> = db_client.grant_types.split_whitespace().map(|s| s.to_string()).collect();
- let response_types: Vec<String> = db_client.response_types.split_whitespace().map(|s| s.to_string()).collect();
+ let scopes: Vec<String> = db_client
+ .scopes
+ .split_whitespace()
+ .map(|s| s.to_string())
+ .collect();
+ let grant_types: Vec<String> = db_client
+ .grant_types
+ .split_whitespace()
+ .map(|s| s.to_string())
+ .collect();
+ let response_types: Vec<String> = db_client
+ .response_types
+ .split_whitespace()
+ .map(|s| s.to_string())
+ .collect();
Ok(OAuthClient {
client_id: db_client.client_id,
@@ -53,8 +65,13 @@ pub struct AuthCodeMapper;
impl DataMapper<AuthorizationCode, DbAuthCode> for AuthCodeMapper {
fn to_domain(&self, db_code: DbAuthCode) -> Result<AuthorizationCode> {
- let scopes = db_code.scope
- .map(|s| s.split_whitespace().map(|scope| scope.to_string()).collect())
+ let scopes = db_code
+ .scope
+ .map(|s| {
+ s.split_whitespace()
+ .map(|scope| scope.to_string())
+ .collect()
+ })
.unwrap_or_default();
Ok(AuthorizationCode {
@@ -99,8 +116,13 @@ pub struct AccessTokenMapper;
impl DataMapper<AccessToken, DbAccessToken> for AccessTokenMapper {
fn to_domain(&self, db_token: DbAccessToken) -> Result<AccessToken> {
- let scopes = db_token.scope
- .map(|s| s.split_whitespace().map(|scope| scope.to_string()).collect())
+ let scopes = db_token
+ .scope
+ .map(|s| {
+ s.split_whitespace()
+ .map(|scope| scope.to_string())
+ .collect()
+ })
.unwrap_or_default();
Ok(AccessToken {
@@ -206,4 +228,4 @@ impl Default for MapperRegistry {
fn default() -> Self {
Self::new()
}
-} \ No newline at end of file
+}
diff --git a/src/domain/mod.rs b/src/domain/mod.rs
index 9a8bfca..7ba3b00 100644
--- a/src/domain/mod.rs
+++ b/src/domain/mod.rs
@@ -16,4 +16,4 @@ pub use queries::*;
pub use repositories::*;
pub use services::*;
pub use specifications::*;
-pub use unit_of_work::*; \ No newline at end of file
+pub use unit_of_work::*;
diff --git a/src/domain/models.rs b/src/domain/models.rs
index 85b554f..26e6df3 100644
--- a/src/domain/models.rs
+++ b/src/domain/models.rs
@@ -108,9 +108,9 @@ pub struct AuthorizationRequest {
#[derive(Debug, Clone, PartialEq)]
pub struct TokenRequest {
pub grant_type: String,
- pub code: Option<String>, // For authorization_code grant
- pub refresh_token: Option<String>, // For refresh_token grant
- pub redirect_uri: Option<String>, // For authorization_code grant
+ pub code: Option<String>, // For authorization_code grant
+ pub refresh_token: Option<String>, // For refresh_token grant
+ pub redirect_uri: Option<String>, // For authorization_code grant
pub client_id: String,
pub client_secret: Option<String>,
// PKCE
@@ -220,4 +220,4 @@ impl Scope {
description: Some("Access to user email address".to_string()),
}
}
-} \ No newline at end of file
+}
diff --git a/src/domain/queries.rs b/src/domain/queries.rs
index d4eb19e..88fb480 100644
--- a/src/domain/queries.rs
+++ b/src/domain/queries.rs
@@ -277,31 +277,24 @@ pub struct CommonQueries;
impl CommonQueries {
/// Get recent failed login attempts (security monitoring)
pub fn recent_failed_logins() -> FailedAuthQuery {
- FailedAuthQuery::new()
- .last_24_hours()
- .min_attempts(3)
+ FailedAuthQuery::new().last_24_hours().min_attempts(3)
}
/// Get audit trail for a specific client
pub fn client_audit_trail(client_id: &str) -> AuditEventsQuery {
- AuditEventsQuery::new()
- .for_client(client_id)
- .limit(1000)
+ AuditEventsQuery::new().for_client(client_id).limit(1000)
}
/// Get token usage statistics for the last 30 days
pub fn monthly_token_usage() -> TokenUsageQuery {
let now = Utc::now();
let thirty_days_ago = now - chrono::Duration::days(30);
-
- TokenUsageQuery::new(TokenUsageGroupBy::Day)
- .date_range(thirty_days_ago, now)
+
+ TokenUsageQuery::new(TokenUsageGroupBy::Day).date_range(thirty_days_ago, now)
}
/// Get all active clients with OpenID scope
pub fn openid_clients() -> OAuthClientsQuery {
- OAuthClientsQuery::new()
- .active_only()
- .with_scope("openid")
+ OAuthClientsQuery::new().active_only().with_scope("openid")
}
-} \ No newline at end of file
+}
diff --git a/src/domain/repositories.rs b/src/domain/repositories.rs
index 1aa4f33..3622373 100644
--- a/src/domain/repositories.rs
+++ b/src/domain/repositories.rs
@@ -36,6 +36,11 @@ pub trait DomainAuditRepository: Send + Sync {
/// Domain-focused repository for rate limiting
pub trait DomainRateRepository: Send + Sync {
fn check_rate_limit(&self, limit: &RateLimit) -> Result<u32>; // Returns current count
- fn increment_rate_limit(&self, identifier: &str, endpoint: &str, window_minutes: u32) -> Result<u32>;
+ fn increment_rate_limit(
+ &self,
+ identifier: &str,
+ endpoint: &str,
+ window_minutes: u32,
+ ) -> Result<u32>;
fn cleanup_old_rate_limits(&self) -> Result<()>;
-} \ No newline at end of file
+}
diff --git a/src/domain/services.rs b/src/domain/services.rs
index 2c23cdc..0e22ddb 100644
--- a/src/domain/services.rs
+++ b/src/domain/services.rs
@@ -3,10 +3,22 @@ use anyhow::Result;
/// Domain service for OAuth2 authorization flow
pub trait AuthorizationService: Send + Sync {
- fn authorize(&self, request: &AuthorizationRequest, user: &User) -> Result<AuthorizationResult, OAuthError>;
+ fn authorize(
+ &self,
+ request: &AuthorizationRequest,
+ user: &User,
+ ) -> Result<AuthorizationResult, OAuthError>;
fn validate_client(&self, client_id: &str) -> Result<OAuthClient, OAuthError>;
- fn validate_redirect_uri(&self, client: &OAuthClient, redirect_uri: &str) -> Result<(), OAuthError>;
- fn validate_scopes(&self, client: &OAuthClient, requested_scopes: &[String]) -> Result<Vec<String>, OAuthError>;
+ fn validate_redirect_uri(
+ &self,
+ client: &OAuthClient,
+ redirect_uri: &str,
+ ) -> Result<(), OAuthError>;
+ fn validate_scopes(
+ &self,
+ client: &OAuthClient,
+ requested_scopes: &[String],
+ ) -> Result<Vec<String>, OAuthError>;
}
/// Domain service for OAuth2 token operations
@@ -23,28 +35,59 @@ pub trait ClientService: Send + Sync {
fn get_client(&self, client_id: &str) -> Result<Option<OAuthClient>>;
fn update_client(&self, client: &OAuthClient) -> Result<()>;
fn delete_client(&self, client_id: &str) -> Result<()>;
- fn authenticate_client(&self, client_id: &str, client_secret: &str) -> Result<OAuthClient, OAuthError>;
+ fn authenticate_client(
+ &self,
+ client_id: &str,
+ client_secret: &str,
+ ) -> Result<OAuthClient, OAuthError>;
}
/// Domain service for user management
pub trait UserService: Send + Sync {
fn get_user(&self, user_id: &str) -> Result<Option<User>>;
fn authenticate_user(&self, username: &str, password: &str) -> Result<User, OAuthError>;
- fn is_user_authorized(&self, user: &User, client: &OAuthClient, scopes: &[String]) -> Result<bool>;
+ fn is_user_authorized(
+ &self,
+ user: &User,
+ client: &OAuthClient,
+ scopes: &[String],
+ ) -> Result<bool>;
}
/// Domain service for audit logging
pub trait AuditService: Send + Sync {
- fn log_authorization_attempt(&self, request: &AuthorizationRequest, user: Option<&User>, success: bool, ip_address: Option<&str>) -> Result<()>;
- fn log_token_request(&self, request: &TokenRequest, success: bool, ip_address: Option<&str>) -> Result<()>;
- fn log_token_introspection(&self, token_hash: &str, client_id: &str, success: bool) -> Result<()>;
+ fn log_authorization_attempt(
+ &self,
+ request: &AuthorizationRequest,
+ user: Option<&User>,
+ success: bool,
+ ip_address: Option<&str>,
+ ) -> Result<()>;
+ fn log_token_request(
+ &self,
+ request: &TokenRequest,
+ success: bool,
+ ip_address: Option<&str>,
+ ) -> Result<()>;
+ fn log_token_introspection(
+ &self,
+ token_hash: &str,
+ client_id: &str,
+ success: bool,
+ ) -> Result<()>;
fn log_token_revocation(&self, token_hash: &str, client_id: &str, success: bool) -> Result<()>;
}
/// Domain service for rate limiting
pub trait RateLimitService: Send + Sync {
fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<(), OAuthError>;
- fn is_rate_limited(&self, identifier: &str, endpoint: &str, max_requests: u32, window_minutes: u32) -> Result<bool>;
+ fn is_rate_limited(
+ &self,
+ identifier: &str,
+ endpoint: &str,
+ max_requests: u32,
+ window_minutes: u32,
+ ) -> Result<bool>;
}
/// Domain service for PKCE operations
@@ -57,7 +100,12 @@ pub trait PkceService: Send + Sync {
/// Domain service for JWT operations
pub trait JwtService: Send + Sync {
fn generate_access_token(&self, claims: &TokenClaims) -> Result<String>;
- fn generate_refresh_token(&self, client_id: &str, user_id: &str, scopes: &[String]) -> Result<String>;
+ fn generate_refresh_token(
+ &self,
+ client_id: &str,
+ user_id: &str,
+ scopes: &[String],
+ ) -> Result<String>;
fn validate_token(&self, token: &str) -> Result<TokenClaims>;
fn get_jwks(&self) -> Result<String>; // JSON Web Key Set
-} \ No newline at end of file
+}
diff --git a/src/domain/specifications.rs b/src/domain/specifications.rs
index 3237d1b..76aafcc 100644
--- a/src/domain/specifications.rs
+++ b/src/domain/specifications.rs
@@ -22,7 +22,7 @@ impl<T> Specification<T> for AndSpecification<T> {
fn is_satisfied_by(&self, candidate: &T) -> bool {
self.left.is_satisfied_by(candidate) && self.right.is_satisfied_by(candidate)
}
-
+
fn reason_for_failure(&self, candidate: &T) -> Option<String> {
if !self.left.is_satisfied_by(candidate) {
self.left.reason_for_failure(candidate)
@@ -40,7 +40,7 @@ impl Specification<OAuthClient> for ActiveClientSpecification {
fn is_satisfied_by(&self, client: &OAuthClient) -> bool {
client.is_active
}
-
+
fn reason_for_failure(&self, _client: &OAuthClient) -> Option<String> {
Some("Client is not active".to_string())
}
@@ -60,7 +60,7 @@ impl Specification<OAuthClient> for ValidRedirectUriSpecification {
fn is_satisfied_by(&self, client: &OAuthClient) -> bool {
client.redirect_uris.contains(&self.redirect_uri)
}
-
+
fn reason_for_failure(&self, _client: &OAuthClient) -> Option<String> {
Some(format!("Invalid redirect_uri: {}", self.redirect_uri))
}
@@ -78,15 +78,19 @@ impl SupportedScopesSpecification {
impl Specification<OAuthClient> for SupportedScopesSpecification {
fn is_satisfied_by(&self, client: &OAuthClient) -> bool {
- self.requested_scopes.iter().all(|scope| client.scopes.contains(scope))
+ self.requested_scopes
+ .iter()
+ .all(|scope| client.scopes.contains(scope))
}
-
+
fn reason_for_failure(&self, client: &OAuthClient) -> Option<String> {
- let unsupported: Vec<_> = self.requested_scopes.iter()
+ let unsupported: Vec<_> = self
+ .requested_scopes
+ .iter()
.filter(|scope| !client.scopes.contains(scope))
.cloned()
.collect();
-
+
if !unsupported.is_empty() {
Some(format!("Unsupported scopes: {}", unsupported.join(", ")))
} else {
@@ -101,7 +105,7 @@ impl Specification<AuthorizationCode> for UnusedAuthCodeSpecification {
fn is_satisfied_by(&self, code: &AuthorizationCode) -> bool {
!code.is_used
}
-
+
fn reason_for_failure(&self, _code: &AuthorizationCode) -> Option<String> {
Some("Authorization code has already been used".to_string())
}
@@ -112,7 +116,7 @@ impl Specification<AuthorizationCode> for ValidAuthCodeSpecification {
fn is_satisfied_by(&self, code: &AuthorizationCode) -> bool {
chrono::Utc::now() < code.expires_at
}
-
+
fn reason_for_failure(&self, _code: &AuthorizationCode) -> Option<String> {
Some("Authorization code has expired".to_string())
}
@@ -132,7 +136,7 @@ impl Specification<AuthorizationCode> for MatchingClientSpecification {
fn is_satisfied_by(&self, code: &AuthorizationCode) -> bool {
code.client_id == self.client_id
}
-
+
fn reason_for_failure(&self, _code: &AuthorizationCode) -> Option<String> {
Some("Client ID mismatch".to_string())
}
@@ -153,14 +157,18 @@ impl Specification<AuthorizationCode> for ValidPkceSpecification {
fn is_satisfied_by(&self, code: &AuthorizationCode) -> bool {
if let Some(challenge) = &code.code_challenge {
let method = code.code_challenge_method.as_deref().unwrap_or("plain");
- crate::oauth::pkce::verify_code_challenge(&self.code_verifier, challenge,
- &crate::oauth::pkce::CodeChallengeMethod::from_str(method).unwrap_or(crate::oauth::pkce::CodeChallengeMethod::Plain)
- ).is_ok()
+ crate::oauth::pkce::verify_code_challenge(
+ &self.code_verifier,
+ challenge,
+ &crate::oauth::pkce::CodeChallengeMethod::from_str(method)
+ .unwrap_or(crate::oauth::pkce::CodeChallengeMethod::Plain),
+ )
+ .is_ok()
} else {
true // No PKCE required
}
}
-
+
fn reason_for_failure(&self, _code: &AuthorizationCode) -> Option<String> {
Some("PKCE verification failed".to_string())
}
@@ -172,7 +180,7 @@ impl Specification<AccessToken> for ValidTokenSpecification {
fn is_satisfied_by(&self, token: &AccessToken) -> bool {
!token.is_revoked && chrono::Utc::now() < token.expires_at
}
-
+
fn reason_for_failure(&self, token: &AccessToken) -> Option<String> {
if token.is_revoked {
Some("Token has been revoked".to_string())
@@ -191,4 +199,4 @@ pub trait SpecificationExt<T>: Specification<T> + Sized + 'static {
}
}
-impl<T, S: Specification<T> + 'static> SpecificationExt<T> for S {} \ No newline at end of file
+impl<T, S: Specification<T> + 'static> SpecificationExt<T> for S {}
diff --git a/src/domain/unit_of_work.rs b/src/domain/unit_of_work.rs
index db8294a..7d6a0e3 100644
--- a/src/domain/unit_of_work.rs
+++ b/src/domain/unit_of_work.rs
@@ -11,10 +11,10 @@ pub trait UnitOfWork: Send + Sync {
pub trait Transaction: Send + Sync {
/// Commit all changes in this transaction
fn commit(self: Box<Self>) -> Result<()>;
-
+
/// Rollback all changes in this transaction
fn rollback(self: Box<Self>) -> Result<()>;
-
+
/// Get repositories within this transaction context
fn client_repository(&self) -> Arc<dyn crate::domain::DomainClientRepository>;
fn auth_code_repository(&self) -> Arc<dyn crate::domain::DomainAuthCodeRepository>;
@@ -31,7 +31,7 @@ impl OAuthUnitOfWork {
pub fn new(uow: Arc<dyn UnitOfWork>) -> Self {
Self { uow }
}
-
+
/// Execute OAuth2 authorization code exchange atomically
pub fn exchange_authorization_code<F>(&self, operation: F) -> Result<()>
where
@@ -46,7 +46,7 @@ impl OAuthUnitOfWork {
}
}
}
-
+
/// Execute token refresh atomically
pub fn refresh_tokens<F>(&self, operation: F) -> Result<()>
where
@@ -61,4 +61,4 @@ impl OAuthUnitOfWork {
}
}
}
-} \ No newline at end of file
+}
diff --git a/src/lib.rs b/src/lib.rs
index e1c1b97..882b65f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,6 +5,7 @@ pub mod database;
pub mod domain;
pub mod http;
pub mod keys;
+pub mod migration_discovery;
pub mod migrations;
pub mod oauth;
pub mod repositories;
diff --git a/src/migration_discovery.rs b/src/migration_discovery.rs
new file mode 100644
index 0000000..fabc660
--- /dev/null
+++ b/src/migration_discovery.rs
@@ -0,0 +1,296 @@
+use anyhow::{Result, anyhow};
+use chrono::Utc;
+use std::collections::BTreeMap;
+use std::path::Path;
+
+#[derive(Debug, Clone)]
+pub struct Migration {
+ pub version: i64,
+ pub name: String,
+ pub sql: String,
+}
+
+/// Migration discovery that reads migration files from the filesystem at runtime
+pub struct RuntimeMigrationDiscovery {
+ migrations_dir: std::path::PathBuf,
+ migrations: BTreeMap<i64, Migration>,
+}
+
+impl RuntimeMigrationDiscovery {
+ pub fn new<P: AsRef<Path>>(migrations_dir: P) -> Result<Self> {
+ let migrations_dir = migrations_dir.as_ref().to_path_buf();
+ let mut discovery = Self {
+ migrations_dir,
+ migrations: BTreeMap::new(),
+ };
+ discovery.discover_migrations()?;
+ Ok(discovery)
+ }
+
+ fn discover_migrations(&mut self) -> Result<()> {
+ if !self.migrations_dir.exists() {
+ return Err(anyhow!(
+ "Migrations directory does not exist: {:?}",
+ self.migrations_dir
+ ));
+ }
+
+ let entries = std::fs::read_dir(&self.migrations_dir)?;
+
+ for entry in entries {
+ let entry = entry?;
+ let path = entry.path();
+
+ if path.is_file() && path.extension().map_or(false, |ext| ext == "sql") {
+ let migration = self.parse_migration_file(&path)?;
+ self.migrations.insert(migration.version, migration);
+ }
+ }
+
+ Ok(())
+ }
+
+ fn parse_migration_file(&self, path: &Path) -> Result<Migration> {
+ let file_name = path
+ .file_name()
+ .and_then(|n| n.to_str())
+ .ok_or_else(|| anyhow!("Invalid migration file name: {:?}", path))?;
+
+ let (version, name) = parse_migration_filename(file_name)?;
+ let sql = std::fs::read_to_string(path)?;
+
+ Ok(Migration { version, name, sql })
+ }
+
+ pub fn get_migrations(&self) -> Vec<&Migration> {
+ self.migrations.values().collect()
+ }
+
+ pub fn get_migration(&self, version: i64) -> Option<&Migration> {
+ self.migrations.get(&version)
+ }
+
+ pub fn get_pending_migrations(&self, current_version: i64) -> Vec<&Migration> {
+ self.migrations
+ .values()
+ .filter(|m| m.version > current_version)
+ .collect()
+ }
+
+ pub fn refresh(&mut self) -> Result<()> {
+ self.migrations.clear();
+ self.discover_migrations()
+ }
+}
+
+/// Parse migration filename to extract timestamp and name
+/// Expected format: "20231201123456_initial_schema.sql" -> (20231201123456, "initial_schema")
+fn parse_migration_filename(filename: &str) -> Result<(i64, String)> {
+ if !filename.ends_with(".sql") {
+ return Err(anyhow!(
+ "Migration file must have .sql extension: {}",
+ filename
+ ));
+ }
+
+ let name_without_ext = &filename[..filename.len() - 4];
+
+ // Find first underscore
+ let underscore_pos = name_without_ext.find('_').ok_or_else(|| {
+ anyhow!(
+ "Migration filename must be in format 'YYYYMMDDHHMMSS_name.sql': {}",
+ filename
+ )
+ })?;
+
+ let timestamp_str = &name_without_ext[..underscore_pos];
+ let name = &name_without_ext[underscore_pos + 1..];
+
+ // Validate timestamp format (should be 14 digits for YYYYMMDDHHMMSS)
+ if timestamp_str.len() != 14 {
+ return Err(anyhow!(
+ "Timestamp must be 14 digits (YYYYMMDDHHMMSS) in migration filename: {}",
+ filename
+ ));
+ }
+
+ let timestamp = timestamp_str
+ .parse::<i64>()
+ .map_err(|_| anyhow!("Invalid timestamp in migration filename: {}", filename))?;
+
+ // Basic validation of timestamp format
+ let year = timestamp / 10000000000;
+ let month = (timestamp / 100000000) % 100;
+ let day = (timestamp / 1000000) % 100;
+
+ if year < 2000 || year > 3000 {
+ return Err(anyhow!("Invalid year in timestamp: {}", timestamp_str));
+ }
+ if month < 1 || month > 12 {
+ return Err(anyhow!("Invalid month in timestamp: {}", timestamp_str));
+ }
+ if day < 1 || day > 31 {
+ return Err(anyhow!("Invalid day in timestamp: {}", timestamp_str));
+ }
+
+ Ok((timestamp, name.to_string()))
+}
+
+/// Generate a timestamp prefix for a new migration file
+/// Returns a string in format YYYYMMDDHHMMSS
+pub fn generate_migration_timestamp() -> String {
+ let now = Utc::now();
+ now.format("%Y%m%d%H%M%S").to_string()
+}
+
+/// Generate a full migration filename with timestamp prefix
+/// Example: generate_migration_filename("add_users_table") -> "20231201123456_add_users_table.sql"
+pub fn generate_migration_filename(name: &str) -> String {
+ format!("{}_{}.sql", generate_migration_timestamp(), name)
+}
+
+/// Trait for migration discovery to allow different implementations
+pub trait MigrationDiscovery {
+ fn get_migrations(&self) -> Vec<&Migration>;
+ fn get_pending_migrations(&self, current_version: i64) -> Vec<&Migration>;
+}
+
+impl MigrationDiscovery for RuntimeMigrationDiscovery {
+ fn get_migrations(&self) -> Vec<&Migration> {
+ self.get_migrations()
+ }
+
+ fn get_pending_migrations(&self, current_version: i64) -> Vec<&Migration> {
+ self.get_pending_migrations(current_version)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::fs;
+ use tempfile::TempDir;
+
+ #[test]
+ fn test_parse_migration_filename() {
+ let cases = vec![
+ (
+ "20231201123456_initial_schema.sql",
+ Ok((20231201123456, "initial_schema".to_string())),
+ ),
+ (
+ "20231202140000_add_users_table.sql",
+ Ok((20231202140000, "add_users_table".to_string())),
+ ),
+ (
+ "20240315091530_complex_migration_name.sql",
+ Ok((20240315091530, "complex_migration_name".to_string())),
+ ),
+ ("invalid.sql", Err(())), // No underscore
+ ("abc_invalid.sql", Err(())), // Non-numeric timestamp
+ ("20231201123456_no_extension", Err(())), // No .sql extension
+ ("123_too_short.sql", Err(())), // Timestamp too short
+ ("202312011234567_too_long.sql", Err(())), // Timestamp too long
+ ("20231300123456_invalid_month.sql", Err(())), // Invalid month
+ ("20231232123456_invalid_day.sql", Err(())), // Invalid day
+ ("19991201123456_invalid_year.sql", Err(())), // Year too old
+ ];
+
+ for (input, expected) in cases {
+ let result = parse_migration_filename(input);
+ match expected {
+ Ok((expected_timestamp, expected_name)) => {
+ assert!(
+ result.is_ok(),
+ "Expected success for {}, got error: {:?}",
+ input,
+ result
+ );
+ let (timestamp, name) = result.unwrap();
+ assert_eq!(
+ timestamp, expected_timestamp,
+ "Timestamp mismatch for {}",
+ input
+ );
+ assert_eq!(name, expected_name, "Name mismatch for {}", input);
+ }
+ Err(_) => {
+ assert!(
+ result.is_err(),
+ "Expected error for {}, got success: {:?}",
+ input,
+ result
+ );
+ }
+ }
+ }
+ }
+
+ #[test]
+ fn test_runtime_migration_discovery() {
+ // Create a temporary directory with migration files
+ let temp_dir = TempDir::new().unwrap();
+ let migrations_dir = temp_dir.path();
+
+ // Create test migration files with timestamp format
+ fs::write(
+ migrations_dir.join("20231201120000_initial_schema.sql"),
+ "CREATE TABLE users (id INTEGER PRIMARY KEY);",
+ )
+ .unwrap();
+
+ fs::write(
+ migrations_dir.join("20231201130000_add_posts.sql"),
+ "CREATE TABLE posts (id INTEGER PRIMARY KEY, user_id INTEGER);",
+ )
+ .unwrap();
+
+ // Test discovery
+ let discovery = RuntimeMigrationDiscovery::new(migrations_dir).unwrap();
+ let migrations = discovery.get_migrations();
+
+ assert_eq!(migrations.len(), 2);
+
+ // Migrations should be sorted by timestamp
+ assert_eq!(migrations[0].version, 20231201120000);
+ assert_eq!(migrations[0].name, "initial_schema");
+ assert_eq!(migrations[1].version, 20231201130000);
+ assert_eq!(migrations[1].name, "add_posts");
+ }
+
+ #[test]
+ fn test_pending_migrations() {
+ // Create a temporary directory with migration files
+ let temp_dir = TempDir::new().unwrap();
+ let migrations_dir = temp_dir.path();
+
+ fs::write(
+ migrations_dir.join("20231201120000_initial_schema.sql"),
+ "CREATE TABLE users (id INTEGER PRIMARY KEY);",
+ )
+ .unwrap();
+
+ fs::write(
+ migrations_dir.join("20231201130000_add_posts.sql"),
+ "CREATE TABLE posts (id INTEGER PRIMARY KEY, user_id INTEGER);",
+ )
+ .unwrap();
+
+ let discovery = RuntimeMigrationDiscovery::new(migrations_dir).unwrap();
+
+ // No pending migrations if we're at the latest timestamp
+ let pending = discovery.get_pending_migrations(20231201130000);
+ assert!(pending.is_empty());
+
+ // All migrations are pending if we're at timestamp 0
+ let pending = discovery.get_pending_migrations(0);
+ assert_eq!(pending.len(), 2);
+ assert_eq!(pending[0].version, 20231201120000);
+ assert_eq!(pending[1].version, 20231201130000);
+
+ // One migration pending if we're at the first timestamp
+ let pending = discovery.get_pending_migrations(20231201120000);
+ assert_eq!(pending.len(), 1);
+ assert_eq!(pending[0].version, 20231201130000);
+ }
+}
diff --git a/src/migrations.rs b/src/migrations.rs
index 61c5b19..021d525 100644
--- a/src/migrations.rs
+++ b/src/migrations.rs
@@ -1,40 +1,31 @@
+use crate::migration_discovery::{Migration, RuntimeMigrationDiscovery};
use anyhow::Result;
use rusqlite::Connection;
-
-pub struct Migration {
- pub version: i32,
- pub name: &'static str,
- pub sql: &'static str,
-}
-
-const MIGRATIONS: &[Migration] = &[
- Migration {
- version: 1,
- name: "initial_schema",
- sql: include_str!("../migrations/001_initial_schema.sql"),
- },
- // Add more migrations here as needed
- // Migration {
- // version: 2,
- // name: "add_user_table",
- // sql: include_str!("../migrations/002_add_user_table.sql"),
- // },
-];
+use std::path::Path;
pub struct MigrationRunner<'a> {
conn: &'a Connection,
+ discovery: RuntimeMigrationDiscovery,
}
impl<'a> MigrationRunner<'a> {
- pub fn new(conn: &'a Connection) -> Self {
- Self { conn }
+ pub fn new(conn: &'a Connection) -> Result<Self> {
+ // Default to migrations directory relative to project root
+ let migrations_dir = Path::new("migrations");
+ let discovery = RuntimeMigrationDiscovery::new(migrations_dir)?;
+ Ok(Self { conn, discovery })
+ }
+
+ pub fn new_with_path<P: AsRef<Path>>(conn: &'a Connection, migrations_dir: P) -> Result<Self> {
+ let discovery = RuntimeMigrationDiscovery::new(migrations_dir)?;
+ Ok(Self { conn, discovery })
}
pub fn run_migrations(&self) -> Result<()> {
// Create migrations table if it doesn't exist
self.conn.execute(
"CREATE TABLE IF NOT EXISTS schema_migrations (
- version INTEGER PRIMARY KEY,
+ version BIGINT PRIMARY KEY,
name TEXT NOT NULL,
applied_at TEXT NOT NULL
)",
@@ -46,22 +37,23 @@ impl<'a> MigrationRunner<'a> {
println!("Current database version: {}", current_version);
- // Run pending migrations
- for migration in MIGRATIONS {
- if migration.version > current_version {
- println!(
- "Running migration {}: {}",
- migration.version, migration.name
- );
- self.run_migration(migration)?;
- }
+ // Get pending migrations from discovery system
+ let pending_migrations = self.discovery.get_pending_migrations(current_version);
+
+ // Run each pending migration
+ for migration in pending_migrations {
+ println!(
+ "Running migration {}: {}",
+ migration.version, migration.name
+ );
+ self.run_migration(migration)?;
}
println!("All migrations completed successfully");
Ok(())
}
- fn get_current_version(&self) -> Result<i32> {
+ fn get_current_version(&self) -> Result<i64> {
// Check if schema_migrations table exists first
let table_exists = self.conn.query_row(
"SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'",
@@ -75,7 +67,7 @@ impl<'a> MigrationRunner<'a> {
let version = self.conn.query_row(
"SELECT COALESCE(MAX(version), 0) FROM schema_migrations",
[],
- |row| row.get::<_, i32>(0),
+ |row| row.get::<_, i64>(0),
)?;
Ok(version)
}
@@ -88,14 +80,14 @@ impl<'a> MigrationRunner<'a> {
fn run_migration(&self, migration: &Migration) -> Result<()> {
// Execute the migration SQL
- self.conn.execute_batch(migration.sql)?;
+ self.conn.execute_batch(&migration.sql)?;
// Record the migration as applied
self.conn.execute(
"INSERT INTO schema_migrations (version, name, applied_at) VALUES (?1, ?2, ?3)",
[
&migration.version.to_string(),
- migration.name,
+ &migration.name,
&chrono::Utc::now().to_rfc3339(),
],
)?;
@@ -103,7 +95,7 @@ impl<'a> MigrationRunner<'a> {
Ok(())
}
- pub fn rollback_to_version(&self, target_version: i32) -> Result<()> {
+ pub fn rollback_to_version(&self, target_version: i64) -> Result<()> {
println!("Rolling back to version {}", target_version);
// This is a simplified rollback - in practice you'd need down migrations
@@ -127,7 +119,7 @@ impl<'a> MigrationRunner<'a> {
let migrations = stmt.query_map([], |row| {
Ok((
- row.get::<_, i32>(0)?,
+ row.get::<_, i64>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
))
@@ -143,13 +135,12 @@ impl<'a> MigrationRunner<'a> {
// Show pending migrations
let current_version = self.get_current_version()?;
- for migration in MIGRATIONS {
- if migration.version > current_version {
- println!(
- "⏳ Migration {}: {} (pending)",
- migration.version, migration.name
- );
- }
+ let pending_migrations = self.discovery.get_pending_migrations(current_version);
+ for migration in pending_migrations {
+ println!(
+ "⏳ Migration {}: {} (pending)",
+ migration.version, migration.name
+ );
}
Ok(())
@@ -162,8 +153,21 @@ mod tests {
#[test]
fn test_migration_runner() {
+ use std::fs;
+ use tempfile::TempDir;
+
+ // Create a temporary directory with a test migration
+ let temp_dir = TempDir::new().unwrap();
+ let migrations_dir = temp_dir.path();
+
+ fs::write(
+ migrations_dir.join("20231201120000_test_migration.sql"),
+ "CREATE TABLE test_table (id INTEGER PRIMARY KEY);",
+ )
+ .unwrap();
+
let conn = Connection::open_in_memory().unwrap();
- let runner = MigrationRunner::new(&conn);
+ let runner = MigrationRunner::new_with_path(&conn, migrations_dir).unwrap();
// Should start with version 0
assert_eq!(runner.get_current_version().unwrap(), 0);
@@ -171,10 +175,7 @@ mod tests {
// Run migrations
runner.run_migrations().unwrap();
- // Should now be at latest version
- assert_eq!(
- runner.get_current_version().unwrap(),
- MIGRATIONS.len() as i32
- );
+ // Should now be at latest version (timestamp of our migration)
+ assert_eq!(runner.get_current_version().unwrap(), 20231201120000);
}
}