summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-11 14:06:07 -0600
committermo khan <mo@mokhan.ca>2025-06-11 14:06:07 -0600
commitd612e590cb4f5b633abc316f2e105924226a7d6f (patch)
tree84103cfbc80099745fedfa2d9f22118dbf539eb5
parent6abae9d4b410bad780635f361d183d043089cf57 (diff)
Add database migrations
-rw-r--r--Cargo.toml8
-rw-r--r--oauth.dbbin98304 -> 102400 bytes
-rw-r--r--src/bin/migrate.rs61
-rw-r--r--src/lib.rs2
-rw-r--r--src/migrations.rs152
5 files changed, 223 insertions, 0 deletions
diff --git a/Cargo.toml b/Cargo.toml
index 18b16c1..3ab38b4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,6 +3,14 @@ name = "sts"
version = "0.1.0"
edition = "2024"
+[[bin]]
+name = "sts"
+path = "src/main.rs"
+
+[[bin]]
+name = "migrate"
+path = "src/bin/migrate.rs"
+
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
diff --git a/oauth.db b/oauth.db
index 1896710..edb0e88 100644
--- a/oauth.db
+++ b/oauth.db
Binary files differ
diff --git a/src/bin/migrate.rs b/src/bin/migrate.rs
new file mode 100644
index 0000000..9afdbf0
--- /dev/null
+++ b/src/bin/migrate.rs
@@ -0,0 +1,61 @@
+use anyhow::Result;
+use rusqlite::Connection;
+use sts::{Config, MigrationRunner};
+use std::env;
+
+fn main() -> Result<()> {
+ let args: Vec<String> = env::args().collect();
+
+ if args.len() < 2 {
+ print_usage();
+ return Ok(());
+ }
+
+ let config = Config::from_env();
+ let conn = Connection::open(&config.database_path)?;
+ let runner = MigrationRunner::new(&conn);
+
+ match args[1].as_str() {
+ "up" => {
+ println!("Running migrations...");
+ runner.run_migrations()?;
+ }
+ "status" => {
+ runner.show_migration_status()?;
+ }
+ "rollback" => {
+ if args.len() < 3 {
+ eprintln!("Error: rollback requires a version number");
+ eprintln!("Usage: cargo run --bin migrate rollback <version>");
+ return Ok(());
+ }
+ let version: i32 = args[2].parse()
+ .map_err(|_| anyhow::anyhow!("Invalid version number: {}", args[2]))?;
+ runner.rollback_to_version(version)?;
+ }
+ _ => {
+ eprintln!("Error: Unknown command '{}'", args[1]);
+ print_usage();
+ }
+ }
+
+ Ok(())
+}
+
+fn print_usage() {
+ println!("OAuth2 STS Migration Tool");
+ println!("========================");
+ println!("");
+ println!("Usage:");
+ println!(" cargo run --bin migrate up # Run pending migrations");
+ println!(" cargo run --bin migrate status # Show migration status");
+ println!(" cargo run --bin migrate rollback <version> # Rollback to version");
+ println!("");
+ println!("Environment Variables:");
+ println!(" DATABASE_PATH Path to SQLite database (default: oauth.db)");
+ println!("");
+ println!("Examples:");
+ println!(" cargo run --bin migrate up");
+ println!(" cargo run --bin migrate status");
+ println!(" cargo run --bin migrate rollback 0");
+} \ No newline at end of file
diff --git a/src/lib.rs b/src/lib.rs
index 0ab228e..eef2cbf 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,10 +3,12 @@ pub mod config;
pub mod database;
pub mod http;
pub mod keys;
+pub mod migrations;
pub mod oauth;
pub use clients::ClientManager;
pub use config::Config;
pub use database::Database;
pub use http::Server;
+pub use migrations::MigrationRunner;
pub use oauth::OAuthServer;
diff --git a/src/migrations.rs b/src/migrations.rs
new file mode 100644
index 0000000..c7cd6bf
--- /dev/null
+++ b/src/migrations.rs
@@ -0,0 +1,152 @@
+use anyhow::Result;
+use rusqlite::Connection;
+
+pub struct Migration {
+ pub version: i32,
+ pub name: &'static str,
+ pub sql: &'static str,
+}
+
+const MIGRATIONS: &[Migration] = &[
+ Migration {
+ version: 1,
+ name: "initial_schema",
+ sql: include_str!("../migrations/001_initial_schema.sql"),
+ },
+ // Add more migrations here as needed
+ // Migration {
+ // version: 2,
+ // name: "add_user_table",
+ // sql: include_str!("../migrations/002_add_user_table.sql"),
+ // },
+];
+
+pub struct MigrationRunner<'a> {
+ conn: &'a Connection,
+}
+
+impl<'a> MigrationRunner<'a> {
+ pub fn new(conn: &'a Connection) -> Self {
+ Self { conn }
+ }
+
+ pub fn run_migrations(&self) -> Result<()> {
+ // Create migrations table if it doesn't exist
+ self.conn.execute(
+ "CREATE TABLE IF NOT EXISTS schema_migrations (
+ version INTEGER PRIMARY KEY,
+ name TEXT NOT NULL,
+ applied_at TEXT NOT NULL
+ )",
+ [],
+ )?;
+
+ // Get current migration version
+ let current_version = self.get_current_version()?;
+
+ println!("Current database version: {}", current_version);
+
+ // Run pending migrations
+ for migration in MIGRATIONS {
+ if migration.version > current_version {
+ println!("Running migration {}: {}", migration.version, migration.name);
+ self.run_migration(migration)?;
+ }
+ }
+
+ println!("All migrations completed successfully");
+ Ok(())
+ }
+
+ fn get_current_version(&self) -> Result<i32> {
+ let version = self.conn.query_row(
+ "SELECT COALESCE(MAX(version), 0) FROM schema_migrations",
+ [],
+ |row| row.get::<_, i32>(0),
+ )?;
+ Ok(version)
+ }
+
+ fn run_migration(&self, migration: &Migration) -> Result<()> {
+ // Execute the migration SQL
+ self.conn.execute_batch(migration.sql)?;
+
+ // Record the migration as applied
+ self.conn.execute(
+ "INSERT INTO schema_migrations (version, name, applied_at) VALUES (?1, ?2, ?3)",
+ [
+ &migration.version.to_string(),
+ migration.name,
+ &chrono::Utc::now().to_rfc3339(),
+ ],
+ )?;
+
+ Ok(())
+ }
+
+ pub fn rollback_to_version(&self, target_version: i32) -> Result<()> {
+ println!("Rolling back to version {}", target_version);
+
+ // This is a simplified rollback - in practice you'd need down migrations
+ // For now, just remove migration records
+ self.conn.execute(
+ "DELETE FROM schema_migrations WHERE version > ?1",
+ [target_version],
+ )?;
+
+ println!("Rollback completed (Note: This doesn't actually undo schema changes)");
+ Ok(())
+ }
+
+ pub fn show_migration_status(&self) -> Result<()> {
+ println!("Migration Status:");
+ println!("================");
+
+ let mut stmt = self.conn.prepare(
+ "SELECT version, name, applied_at FROM schema_migrations ORDER BY version"
+ )?;
+
+ let migrations = stmt.query_map([], |row| {
+ Ok((
+ row.get::<_, i32>(0)?,
+ row.get::<_, String>(1)?,
+ row.get::<_, String>(2)?,
+ ))
+ })?;
+
+ for migration in migrations {
+ let (version, name, applied_at) = migration?;
+ println!("✅ Migration {}: {} (applied: {})", version, name, applied_at);
+ }
+
+ // Show pending migrations
+ let current_version = self.get_current_version()?;
+ for migration in MIGRATIONS {
+ if migration.version > current_version {
+ println!("⏳ Migration {}: {} (pending)", migration.version, migration.name);
+ }
+ }
+
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_migration_runner() {
+ let conn = Connection::open_in_memory().unwrap();
+ let runner = MigrationRunner::new(&conn);
+
+ // Should start with version 0
+ assert_eq!(runner.get_current_version().unwrap(), 0);
+
+ // Run migrations
+ runner.run_migrations().unwrap();
+
+ // Should now be at latest version
+ assert_eq!(runner.get_current_version().unwrap(), MIGRATIONS.len() as i32);
+ }
+} \ No newline at end of file