diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-11 14:06:07 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-11 14:06:07 -0600 |
| commit | d612e590cb4f5b633abc316f2e105924226a7d6f (patch) | |
| tree | 84103cfbc80099745fedfa2d9f22118dbf539eb5 /src | |
| parent | 6abae9d4b410bad780635f361d183d043089cf57 (diff) | |
Add database migrations
Diffstat (limited to 'src')
| -rw-r--r-- | src/bin/migrate.rs | 61 | ||||
| -rw-r--r-- | src/lib.rs | 2 | ||||
| -rw-r--r-- | src/migrations.rs | 152 |
3 files changed, 215 insertions, 0 deletions
diff --git a/src/bin/migrate.rs b/src/bin/migrate.rs new file mode 100644 index 0000000..9afdbf0 --- /dev/null +++ b/src/bin/migrate.rs @@ -0,0 +1,61 @@ +use anyhow::Result; +use rusqlite::Connection; +use sts::{Config, MigrationRunner}; +use std::env; + +fn main() -> Result<()> { + let args: Vec<String> = env::args().collect(); + + if args.len() < 2 { + print_usage(); + return Ok(()); + } + + let config = Config::from_env(); + let conn = Connection::open(&config.database_path)?; + let runner = MigrationRunner::new(&conn); + + match args[1].as_str() { + "up" => { + println!("Running migrations..."); + runner.run_migrations()?; + } + "status" => { + runner.show_migration_status()?; + } + "rollback" => { + if args.len() < 3 { + eprintln!("Error: rollback requires a version number"); + eprintln!("Usage: cargo run --bin migrate rollback <version>"); + return Ok(()); + } + let version: i32 = args[2].parse() + .map_err(|_| anyhow::anyhow!("Invalid version number: {}", args[2]))?; + runner.rollback_to_version(version)?; + } + _ => { + eprintln!("Error: Unknown command '{}'", args[1]); + print_usage(); + } + } + + Ok(()) +} + +fn print_usage() { + println!("OAuth2 STS Migration Tool"); + println!("========================"); + println!(""); + println!("Usage:"); + println!(" cargo run --bin migrate up # Run pending migrations"); + println!(" cargo run --bin migrate status # Show migration status"); + println!(" cargo run --bin migrate rollback <version> # Rollback to version"); + println!(""); + println!("Environment Variables:"); + println!(" DATABASE_PATH Path to SQLite database (default: oauth.db)"); + println!(""); + println!("Examples:"); + println!(" cargo run --bin migrate up"); + println!(" cargo run --bin migrate status"); + println!(" cargo run --bin migrate rollback 0"); +}
\ No newline at end of file @@ -3,10 +3,12 @@ pub mod config; pub mod database; pub mod http; pub mod keys; +pub mod migrations; pub mod oauth; pub use clients::ClientManager; pub use config::Config; pub use database::Database; pub use http::Server; +pub use migrations::MigrationRunner; pub use oauth::OAuthServer; diff --git a/src/migrations.rs b/src/migrations.rs new file mode 100644 index 0000000..c7cd6bf --- /dev/null +++ b/src/migrations.rs @@ -0,0 +1,152 @@ +use anyhow::Result; +use rusqlite::Connection; + +pub struct Migration { + pub version: i32, + pub name: &'static str, + pub sql: &'static str, +} + +const MIGRATIONS: &[Migration] = &[ + Migration { + version: 1, + name: "initial_schema", + sql: include_str!("../migrations/001_initial_schema.sql"), + }, + // Add more migrations here as needed + // Migration { + // version: 2, + // name: "add_user_table", + // sql: include_str!("../migrations/002_add_user_table.sql"), + // }, +]; + +pub struct MigrationRunner<'a> { + conn: &'a Connection, +} + +impl<'a> MigrationRunner<'a> { + pub fn new(conn: &'a Connection) -> Self { + Self { conn } + } + + pub fn run_migrations(&self) -> Result<()> { + // Create migrations table if it doesn't exist + self.conn.execute( + "CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + name TEXT NOT NULL, + applied_at TEXT NOT NULL + )", + [], + )?; + + // Get current migration version + let current_version = self.get_current_version()?; + + println!("Current database version: {}", current_version); + + // Run pending migrations + for migration in MIGRATIONS { + if migration.version > current_version { + println!("Running migration {}: {}", migration.version, migration.name); + self.run_migration(migration)?; + } + } + + println!("All migrations completed successfully"); + Ok(()) + } + + fn get_current_version(&self) -> Result<i32> { + let version = self.conn.query_row( + "SELECT COALESCE(MAX(version), 0) FROM schema_migrations", + [], + |row| row.get::<_, i32>(0), + )?; + Ok(version) + } + + fn run_migration(&self, migration: &Migration) -> Result<()> { + // Execute the migration SQL + self.conn.execute_batch(migration.sql)?; + + // Record the migration as applied + self.conn.execute( + "INSERT INTO schema_migrations (version, name, applied_at) VALUES (?1, ?2, ?3)", + [ + &migration.version.to_string(), + migration.name, + &chrono::Utc::now().to_rfc3339(), + ], + )?; + + Ok(()) + } + + pub fn rollback_to_version(&self, target_version: i32) -> Result<()> { + println!("Rolling back to version {}", target_version); + + // This is a simplified rollback - in practice you'd need down migrations + // For now, just remove migration records + self.conn.execute( + "DELETE FROM schema_migrations WHERE version > ?1", + [target_version], + )?; + + println!("Rollback completed (Note: This doesn't actually undo schema changes)"); + Ok(()) + } + + pub fn show_migration_status(&self) -> Result<()> { + println!("Migration Status:"); + println!("================"); + + let mut stmt = self.conn.prepare( + "SELECT version, name, applied_at FROM schema_migrations ORDER BY version" + )?; + + let migrations = stmt.query_map([], |row| { + Ok(( + row.get::<_, i32>(0)?, + row.get::<_, String>(1)?, + row.get::<_, String>(2)?, + )) + })?; + + for migration in migrations { + let (version, name, applied_at) = migration?; + println!("✅ Migration {}: {} (applied: {})", version, name, applied_at); + } + + // Show pending migrations + let current_version = self.get_current_version()?; + for migration in MIGRATIONS { + if migration.version > current_version { + println!("⏳ Migration {}: {} (pending)", migration.version, migration.name); + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_migration_runner() { + let conn = Connection::open_in_memory().unwrap(); + let runner = MigrationRunner::new(&conn); + + // Should start with version 0 + assert_eq!(runner.get_current_version().unwrap(), 0); + + // Run migrations + runner.run_migrations().unwrap(); + + // Should now be at latest version + assert_eq!(runner.get_current_version().unwrap(), MIGRATIONS.len() as i32); + } +}
\ No newline at end of file |
