diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-11 15:15:41 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-11 15:15:41 -0600 |
| commit | aea6bd6ec7d7e70a67723edf6327df4a9cc65d89 (patch) | |
| tree | 80fcb6cbda7baa5ed15cf044d7583acb2438c4d2 | |
| parent | 4435ee26b79648e92d0f172e42f9e6629e955505 (diff) | |
chore: run rustfmt again
| -rw-r--r-- | src/clients.rs | 112 | ||||
| -rw-r--r-- | src/config.rs | 6 | ||||
| -rw-r--r-- | src/database.rs | 146 | ||||
| -rw-r--r-- | src/keys.rs | 28 | ||||
| -rw-r--r-- | src/main.rs | 49 | ||||
| -rw-r--r-- | src/migrations.rs | 46 |
6 files changed, 247 insertions, 140 deletions
diff --git a/src/clients.rs b/src/clients.rs index bc9aea5..7941d8c 100644 --- a/src/clients.rs +++ b/src/clients.rs @@ -1,12 +1,12 @@ +use crate::database::{Database, DbOAuthClient}; +use anyhow::Result; use base64::Engine; +use chrono::Utc; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; use uuid::Uuid; -use chrono::Utc; -use crate::database::{Database, DbOAuthClient}; -use anyhow::Result; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OAuthClient { @@ -37,10 +37,10 @@ impl ClientManager { clients: HashMap::new(), database: database.clone(), }; - + // Load existing clients from database into cache manager.load_clients_from_db()?; - + // Register a default test client for development if it doesn't exist if manager.get_client_from_db("test_client")?.is_none() { let _ = manager.register_client( @@ -48,13 +48,17 @@ impl ClientManager { "test_secret".to_string(), vec!["http://localhost:3000/callback".to_string()], "Test Client".to_string(), - vec!["openid".to_string(), "profile".to_string(), "email".to_string()], + vec![ + "openid".to_string(), + "profile".to_string(), + "email".to_string(), + ], ); // Ignore errors if client already exists } - + Ok(manager) } - + fn load_clients_from_db(&mut self) -> Result<()> { // This is a simplified version - in practice you'd want to load all clients // For now we'll load on-demand @@ -129,40 +133,52 @@ impl ClientManager { if let Some(client) = self.clients.get(client_id) { return Some(client); } - + // If not in cache, try to load from database // For thread safety, we can't mutate self here, so we'll return None // In a real implementation, you'd want a more sophisticated caching strategy None } - + pub fn get_client_from_db(&mut self, client_id: &str) -> Result<Option<OAuthClient>> { // Check cache first if let Some(client) = self.clients.get(client_id) { return Ok(Some(client.clone())); } - + // Load from database let db_client = { let db = self.database.lock().unwrap(); db.get_oauth_client(client_id)? }; - + if let Some(db_client) = db_client { let redirect_uris: Vec<String> = serde_json::from_str(&db_client.redirect_uris)?; - let scopes: Vec<String> = db_client.scopes.split_whitespace().map(|s| s.to_string()).collect(); - + let scopes: Vec<String> = db_client + .scopes + .split_whitespace() + .map(|s| s.to_string()) + .collect(); + let client = OAuthClient { client_id: db_client.client_id.clone(), client_secret_hash: db_client.client_secret_hash, redirect_uris, client_name: db_client.client_name, scopes, - grant_types: db_client.grant_types.split_whitespace().map(|s| s.to_string()).collect(), - response_types: db_client.response_types.split_whitespace().map(|s| s.to_string()).collect(), + grant_types: db_client + .grant_types + .split_whitespace() + .map(|s| s.to_string()) + .collect(), + response_types: db_client + .response_types + .split_whitespace() + .map(|s| s.to_string()) + .collect(), created_at: db_client.created_at.timestamp() as u64, }; - + // Cache it self.clients.insert(db_client.client_id, client.clone()); Ok(Some(client)) @@ -181,7 +197,7 @@ impl ClientManager { return false; } }; - + let provided_hash = Self::hash_secret(client_secret); // Use constant-time comparison to prevent timing attacks self.constant_time_eq(&client.client_secret_hash, &provided_hash) @@ -199,7 +215,9 @@ impl ClientManager { if let Ok(Some(client)) = self.get_client_from_db(client_id) { if let Some(scopes_str) = requested_scopes { let requested: Vec<&str> = scopes_str.split_whitespace().collect(); - requested.iter().all(|scope| client.scopes.contains(&scope.to_string())) + requested + .iter() + .all(|scope| client.scopes.contains(&scope.to_string())) } else { true // No scopes requested is valid } @@ -224,7 +242,7 @@ impl ClientManager { if a.len() != b.len() { return false; } - + let mut result = 0u8; for (byte_a, byte_b) in a.bytes().zip(b.bytes()) { result |= byte_a ^ byte_b; @@ -232,16 +250,24 @@ impl ClientManager { result == 0 } - pub fn generate_client_credentials(&mut self, client_name: String, redirect_uris: Vec<String>) -> Result<ClientCredentials> { + pub fn generate_client_credentials( + &mut self, + client_name: String, + redirect_uris: Vec<String>, + ) -> Result<ClientCredentials> { let client_id = format!("client_{}", Uuid::new_v4().to_string().replace("-", "")); let client_secret = Uuid::new_v4().to_string(); - + self.register_client( client_id.clone(), client_secret.clone(), redirect_uris, client_name, - vec!["openid".to_string(), "profile".to_string(), "email".to_string()], + vec![ + "openid".to_string(), + "profile".to_string(), + "email".to_string(), + ], )?; Ok(ClientCredentials { @@ -253,7 +279,7 @@ impl ClientManager { pub fn list_clients(&self) -> Vec<&OAuthClient> { self.clients.values().collect() } - + pub fn list_all_clients_from_db(&self) -> Result<Vec<DbOAuthClient>> { // This would require a new database method - for now return empty Ok(vec![]) @@ -270,13 +296,13 @@ pub fn parse_basic_auth(auth_header: &str) -> Option<(String, String)> { let decoded = base64::engine::general_purpose::STANDARD .decode(encoded) .ok()?; - + let credentials = String::from_utf8(decoded).ok()?; let mut parts = credentials.splitn(2, ':'); - + let username = parts.next()?.to_string(); let password = parts.next()?.to_string(); - + Some((username, password)) } @@ -288,7 +314,7 @@ mod disabled_tests { #[test] fn test_client_registration() { let mut manager = ClientManager::new(); - + let result = manager.register_client( "new_client".to_string(), "secret123".to_string(), @@ -296,7 +322,7 @@ mod disabled_tests { "Test App".to_string(), vec!["openid".to_string()], ); - + assert!(result.is_ok()); let client = result.unwrap(); assert_eq!(client.client_id, "new_client"); @@ -306,7 +332,7 @@ mod disabled_tests { #[test] fn test_duplicate_client_id() { let mut manager = ClientManager::new(); - + manager.register_client( "duplicate".to_string(), "secret1".to_string(), @@ -314,7 +340,7 @@ mod disabled_tests { "App 1".to_string(), vec!["openid".to_string()], ).unwrap(); - + let result = manager.register_client( "duplicate".to_string(), "secret2".to_string(), @@ -322,14 +348,14 @@ mod disabled_tests { "App 2".to_string(), vec!["openid".to_string()], ); - + assert!(result.is_err()); } #[test] fn test_client_authentication() { let mut manager = ClientManager::new(); - + manager.register_client( "auth_client".to_string(), "correct_secret".to_string(), @@ -337,7 +363,7 @@ mod disabled_tests { "Auth Test".to_string(), vec!["openid".to_string()], ).unwrap(); - + assert!(manager.authenticate_client("auth_client", "correct_secret")); assert!(!manager.authenticate_client("auth_client", "wrong_secret")); assert!(!manager.authenticate_client("nonexistent", "any_secret")); @@ -346,7 +372,7 @@ mod disabled_tests { #[test] fn test_redirect_uri_validation() { let mut manager = ClientManager::new(); - + manager.register_client( "uri_client".to_string(), "secret".to_string(), @@ -354,7 +380,7 @@ mod disabled_tests { "URI Test".to_string(), vec!["openid".to_string()], ).unwrap(); - + assert!(manager.is_redirect_uri_valid("uri_client", "https://app.com/callback")); assert!(manager.is_redirect_uri_valid("uri_client", "http://localhost:3000/callback")); assert!(!manager.is_redirect_uri_valid("uri_client", "https://evil.com/callback")); @@ -364,7 +390,7 @@ mod disabled_tests { #[test] fn test_scope_validation() { let mut manager = ClientManager::new(); - + manager.register_client( "scope_client".to_string(), "secret".to_string(), @@ -372,7 +398,7 @@ mod disabled_tests { "Scope Test".to_string(), vec!["openid".to_string(), "profile".to_string()], ).unwrap(); - + assert!(manager.is_scope_valid("scope_client", &Some("openid".to_string()))); assert!(manager.is_scope_valid("scope_client", &Some("openid profile".to_string()))); assert!(!manager.is_scope_valid("scope_client", &Some("openid profile email".to_string()))); @@ -383,7 +409,7 @@ mod disabled_tests { fn test_basic_auth_parsing() { let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; // test_client:test_secret let result = parse_basic_auth(auth_header); - + assert!(result.is_some()); let (username, password) = result.unwrap(); assert_eq!(username, "test_client"); @@ -393,19 +419,19 @@ mod disabled_tests { #[test] fn test_generate_client_credentials() { let mut manager = ClientManager::new(); - + let result = manager.generate_client_credentials( "Generated App".to_string(), vec!["https://generated.com/callback".to_string()], ); - + assert!(result.is_ok()); let credentials = result.unwrap(); assert!(credentials.client_id.starts_with("client_")); assert!(!credentials.client_secret.is_empty()); - + // Verify the client was actually registered assert!(manager.authenticate_client(&credentials.client_id, &credentials.client_secret)); } } -*/
\ No newline at end of file +*/ diff --git a/src/config.rs b/src/config.rs index 266f669..e496581 100644 --- a/src/config.rs +++ b/src/config.rs @@ -13,8 +13,10 @@ pub struct Config { impl Config { pub fn from_env() -> Self { let bind_addr = std::env::var("BIND_ADDR").unwrap_or_else(|_| "127.0.0.1:7878".to_string()); - let issuer_url = std::env::var("ISSUER_URL").unwrap_or_else(|_| format!("http://{}", bind_addr)); - let database_path = std::env::var("DATABASE_PATH").unwrap_or_else(|_| "oauth.db".to_string()); + let issuer_url = + std::env::var("ISSUER_URL").unwrap_or_else(|_| format!("http://{}", bind_addr)); + let database_path = + std::env::var("DATABASE_PATH").unwrap_or_else(|_| "oauth.db".to_string()); let rate_limit_requests_per_minute = std::env::var("RATE_LIMIT_RPM") .unwrap_or_else(|_| "60".to_string()) .parse() diff --git a/src/database.rs b/src/database.rs index dc33cf8..2472d1a 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,6 +1,6 @@ use anyhow::Result; use chrono::{DateTime, Utc}; -use rusqlite::{params, Connection}; +use rusqlite::{Connection, params}; use serde::{Deserialize, Serialize}; use std::path::Path; @@ -10,9 +10,9 @@ pub struct DbOAuthClient { pub client_id: String, pub client_secret_hash: String, pub client_name: String, - pub redirect_uris: String, // JSON array - pub scopes: String, // Space-separated - pub grant_types: String, // Space-separated + pub redirect_uris: String, // JSON array + pub scopes: String, // Space-separated + pub grant_types: String, // Space-separated pub response_types: String, // Space-separated pub created_at: DateTime<Utc>, pub updated_at: DateTime<Utc>, @@ -270,7 +270,7 @@ impl Database { "INSERT INTO oauth_clients (client_id, client_secret_hash, client_name, redirect_uris, scopes, grant_types, response_types, created_at, updated_at, is_active) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)" + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)", )?; let id = stmt.insert(params![ @@ -293,7 +293,7 @@ impl Database { let mut stmt = self.conn.prepare( "SELECT id, client_id, client_secret_hash, client_name, redirect_uris, scopes, grant_types, response_types, created_at, updated_at, is_active - FROM oauth_clients WHERE client_id = ?1 AND is_active = 1" + FROM oauth_clients WHERE client_id = ?1 AND is_active = 1", )?; let client = stmt.query_row([client_id], |row| { @@ -307,10 +307,22 @@ impl Database { grant_types: row.get(6)?, response_types: row.get(7)?, created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?) - .map_err(|_| rusqlite::Error::InvalidColumnType(8, "created_at".to_string(), rusqlite::types::Type::Text))? + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 8, + "created_at".to_string(), + rusqlite::types::Type::Text, + ) + })? .with_timezone(&Utc), updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?) - .map_err(|_| rusqlite::Error::InvalidColumnType(9, "updated_at".to_string(), rusqlite::types::Type::Text))? + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 9, + "updated_at".to_string(), + rusqlite::types::Type::Text, + ) + })? .with_timezone(&Utc), is_active: row.get(10)?, }) @@ -329,7 +341,7 @@ impl Database { "INSERT INTO auth_codes (code, client_id, user_id, redirect_uri, scope, expires_at, created_at, is_used, code_challenge, code_challenge_method) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)" + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)", )?; let id = stmt.insert(params![ @@ -352,7 +364,7 @@ impl Database { let mut stmt = self.conn.prepare( "SELECT id, code, client_id, user_id, redirect_uri, scope, expires_at, created_at, is_used, code_challenge, code_challenge_method - FROM auth_codes WHERE code = ?1" + FROM auth_codes WHERE code = ?1", )?; let auth_code = stmt.query_row([code], |row| { @@ -364,10 +376,22 @@ impl Database { redirect_uri: row.get(4)?, scope: row.get(5)?, expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) - .map_err(|_| rusqlite::Error::InvalidColumnType(6, "expires_at".to_string(), rusqlite::types::Type::Text))? + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 6, + "expires_at".to_string(), + rusqlite::types::Type::Text, + ) + })? .with_timezone(&Utc), created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?) - .map_err(|_| rusqlite::Error::InvalidColumnType(7, "created_at".to_string(), rusqlite::types::Type::Text))? + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 7, + "created_at".to_string(), + rusqlite::types::Type::Text, + ) + })? .with_timezone(&Utc), is_used: row.get(8)?, code_challenge: row.get(9)?, @@ -383,10 +407,8 @@ impl Database { } pub fn mark_auth_code_used(&self, code: &str) -> Result<()> { - self.conn.execute( - "UPDATE auth_codes SET is_used = 1 WHERE code = ?1", - [code], - )?; + self.conn + .execute("UPDATE auth_codes SET is_used = 1 WHERE code = ?1", [code])?; Ok(()) } @@ -395,7 +417,7 @@ impl Database { let mut stmt = self.conn.prepare( "INSERT INTO access_tokens (token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)" + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", )?; let id = stmt.insert(params![ @@ -426,10 +448,22 @@ impl Database { user_id: row.get(3)?, scope: row.get(4)?, expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?) - .map_err(|_| rusqlite::Error::InvalidColumnType(5, "expires_at".to_string(), rusqlite::types::Type::Text))? + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 5, + "expires_at".to_string(), + rusqlite::types::Type::Text, + ) + })? .with_timezone(&Utc), created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) - .map_err(|_| rusqlite::Error::InvalidColumnType(6, "created_at".to_string(), rusqlite::types::Type::Text))? + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 6, + "created_at".to_string(), + rusqlite::types::Type::Text, + ) + })? .with_timezone(&Utc), is_revoked: row.get(7)?, token_hash: row.get(8)?, @@ -455,7 +489,7 @@ impl Database { pub fn create_rsa_key(&self, key: &DbRsaKey) -> Result<i64> { let mut stmt = self.conn.prepare( "INSERT INTO rsa_keys (kid, private_key_pem, public_key_pem, created_at, is_current) - VALUES (?1, ?2, ?3, ?4, ?5)" + VALUES (?1, ?2, ?3, ?4, ?5)", )?; let id = stmt.insert(params![ @@ -472,7 +506,7 @@ impl Database { pub fn get_current_rsa_key(&self) -> Result<Option<DbRsaKey>> { let mut stmt = self.conn.prepare( "SELECT id, kid, private_key_pem, public_key_pem, created_at, is_current - FROM rsa_keys WHERE is_current = 1 ORDER BY created_at DESC LIMIT 1" + FROM rsa_keys WHERE is_current = 1 ORDER BY created_at DESC LIMIT 1", )?; let key = stmt.query_row([], |row| { @@ -482,7 +516,13 @@ impl Database { private_key_pem: row.get(2)?, public_key_pem: row.get(3)?, created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?) - .map_err(|_| rusqlite::Error::InvalidColumnType(4, "created_at".to_string(), rusqlite::types::Type::Text))? + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 4, + "created_at".to_string(), + rusqlite::types::Type::Text, + ) + })? .with_timezone(&Utc), is_current: row.get(5)?, }) @@ -498,7 +538,7 @@ impl Database { pub fn get_all_rsa_keys(&self) -> Result<Vec<DbRsaKey>> { let mut stmt = self.conn.prepare( "SELECT id, kid, private_key_pem, public_key_pem, created_at, is_current - FROM rsa_keys ORDER BY created_at DESC" + FROM rsa_keys ORDER BY created_at DESC", )?; let keys = stmt.query_map([], |row| { @@ -508,7 +548,13 @@ impl Database { private_key_pem: row.get(2)?, public_key_pem: row.get(3)?, created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?) - .map_err(|_| rusqlite::Error::InvalidColumnType(4, "created_at".to_string(), rusqlite::types::Type::Text))? + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 4, + "created_at".to_string(), + rusqlite::types::Type::Text, + ) + })? .with_timezone(&Utc), is_current: row.get(5)?, }) @@ -523,14 +569,13 @@ impl Database { pub fn set_current_rsa_key(&self, kid: &str) -> Result<()> { // First, unset all current keys - self.conn.execute("UPDATE rsa_keys SET is_current = 0", [])?; - + self.conn + .execute("UPDATE rsa_keys SET is_current = 0", [])?; + // Then set the specified key as current - self.conn.execute( - "UPDATE rsa_keys SET is_current = 1 WHERE kid = ?1", - [kid], - )?; - + self.conn + .execute("UPDATE rsa_keys SET is_current = 1 WHERE kid = ?1", [kid])?; + Ok(()) } @@ -539,7 +584,7 @@ impl Database { let mut stmt = self.conn.prepare( "INSERT INTO audit_logs (event_type, client_id, user_id, ip_address, user_agent, details, created_at, success) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)" + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", )?; let id = stmt.insert(params![ @@ -557,7 +602,12 @@ impl Database { } // Rate Limiting operations - pub fn increment_rate_limit(&self, identifier: &str, endpoint: &str, window_minutes: i32) -> Result<i32> { + pub fn increment_rate_limit( + &self, + identifier: &str, + endpoint: &str, + window_minutes: i32, + ) -> Result<i32> { let now = Utc::now(); let window_start = now - chrono::Duration::minutes(window_minutes as i64); @@ -630,7 +680,7 @@ mod tests { #[test] fn test_oauth_client_operations() { let db = Database::new_in_memory().expect("Failed to create database"); - + let client = DbOAuthClient { id: 0, client_id: "test_client".to_string(), @@ -645,10 +695,14 @@ mod tests { is_active: true, }; - let id = db.create_oauth_client(&client).expect("Failed to create client"); + let id = db + .create_oauth_client(&client) + .expect("Failed to create client"); assert!(id > 0); - let retrieved = db.get_oauth_client("test_client").expect("Failed to get client"); + let retrieved = db + .get_oauth_client("test_client") + .expect("Failed to get client"); assert!(retrieved.is_some()); assert_eq!(retrieved.unwrap().client_name, "Test Client"); } @@ -671,7 +725,8 @@ mod tests { updated_at: Utc::now(), is_active: true, }; - db.create_oauth_client(&client).expect("Failed to create client"); + db.create_oauth_client(&client) + .expect("Failed to create client"); let auth_code = DbAuthCode { id: 0, @@ -687,17 +742,24 @@ mod tests { code_challenge_method: Some("S256".to_string()), }; - let id = db.create_auth_code(&auth_code).expect("Failed to create auth code"); + let id = db + .create_auth_code(&auth_code) + .expect("Failed to create auth code"); assert!(id > 0); - let retrieved = db.get_auth_code("test_code_123").expect("Failed to get auth code"); + let retrieved = db + .get_auth_code("test_code_123") + .expect("Failed to get auth code"); assert!(retrieved.is_some()); let code = retrieved.unwrap(); assert_eq!(code.client_id, "test_client"); assert_eq!(code.is_used, false); - db.mark_auth_code_used("test_code_123").expect("Failed to mark code as used"); - let updated = db.get_auth_code("test_code_123").expect("Failed to get auth code"); + db.mark_auth_code_used("test_code_123") + .expect("Failed to mark code as used"); + let updated = db + .get_auth_code("test_code_123") + .expect("Failed to get auth code"); assert_eq!(updated.unwrap().is_used, true); } -}
\ No newline at end of file +} diff --git a/src/keys.rs b/src/keys.rs index 16b943c..675eb61 100644 --- a/src/keys.rs +++ b/src/keys.rs @@ -1,6 +1,9 @@ +use crate::database::{Database, DbRsaKey}; +use anyhow::Result; use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; +use chrono::Utc; use jsonwebtoken::{DecodingKey, EncodingKey}; -use rsa::pkcs8::{EncodePrivateKey, EncodePublicKey, DecodePrivateKey, DecodePublicKey}; +use rsa::pkcs8::{DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey}; use rsa::traits::PublicKeyParts; use rsa::{RsaPrivateKey, RsaPublicKey}; use serde::Serialize; @@ -8,9 +11,6 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; use std::time::{SystemTime, UNIX_EPOCH}; use uuid::Uuid; -use chrono::Utc; -use crate::database::{Database, DbRsaKey}; -use anyhow::Result; #[derive(Clone)] pub struct KeyPair { @@ -56,28 +56,28 @@ impl KeyManager { // Load existing keys from database manager.load_keys_from_db()?; - + // If no keys exist, generate the first one if manager.keys.is_empty() { manager.generate_new_key()?; } - + Ok(manager) } - + fn load_keys_from_db(&mut self) -> Result<()> { let db_keys = { let db = self.database.lock().unwrap(); db.get_all_rsa_keys()? }; - + for db_key in db_keys { let private_key = RsaPrivateKey::from_pkcs8_pem(&db_key.private_key_pem)?; let public_key = RsaPublicKey::from_public_key_pem(&db_key.public_key_pem)?; - + let encoding_key = EncodingKey::from_rsa_pem(db_key.private_key_pem.as_bytes())?; let decoding_key = DecodingKey::from_rsa_pem(db_key.public_key_pem.as_bytes())?; - + let key_pair = KeyPair { kid: db_key.kid.clone(), private_key, @@ -86,14 +86,14 @@ impl KeyManager { encoding_key, decoding_key, }; - + self.keys.insert(db_key.kid.clone(), key_pair); - + if db_key.is_current { self.current_key_id = Some(db_key.kid); } } - + Ok(()) } @@ -121,7 +121,7 @@ impl KeyManager { created_at: now, is_current: true, // This will be the new current key }; - + { let db = self.database.lock().unwrap(); db.create_rsa_key(&db_key)?; diff --git a/src/main.rs b/src/main.rs index f5951e0..0612bfa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,31 +1,36 @@ -use sts::http::Server; -use sts::Config; use std::thread; use std::time::Duration; +use sts::Config; +use sts::http::Server; fn main() { let config = Config::from_env(); let server = Server::new(config.clone()).expect("Failed to create server"); - + // Start cleanup task in background let cleanup_config = config.clone(); thread::spawn(move || { loop { - thread::sleep(Duration::from_secs(cleanup_config.cleanup_interval_hours as u64 * 3600)); + thread::sleep(Duration::from_secs( + cleanup_config.cleanup_interval_hours as u64 * 3600, + )); // Note: In the current implementation, we don't have direct access to the OAuth server // from here to call cleanup_expired_data(). In a production implementation, // you'd want to structure this differently or use a background job queue. } }); - + println!("Starting OAuth2 STS server..."); println!("Configuration:"); println!(" Bind Address: {}", config.bind_addr); println!(" Issuer URL: {}", config.issuer_url); println!(" Database: {}", config.database_path); - println!(" Rate Limit: {} requests/minute", config.rate_limit_requests_per_minute); + println!( + " Rate Limit: {} requests/minute", + config.rate_limit_requests_per_minute + ); println!(" Audit Logging: {}", config.enable_audit_logging); - + server.start(); } @@ -102,15 +107,15 @@ mod disabled_tests { fn test_jwks_endpoint() { let config = sts::Config::from_env(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); - + let jwks_json = oauth_server.get_jwks(); assert!(!jwks_json.is_empty()); - + // Parse the JSON to verify structure let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON"); assert!(jwks["keys"].is_array()); assert!(jwks["keys"].as_array().unwrap().len() > 0); - + let key = &jwks["keys"][0]; assert_eq!(key["kty"], "RSA"); assert_eq!(key["use"], "sig"); @@ -133,7 +138,7 @@ mod disabled_tests { auth_params.insert("state".to_string(), "test_state".to_string()); let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed"); - + // Extract the authorization code from redirect URL let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url @@ -150,17 +155,17 @@ mod disabled_tests { token_params.insert("client_secret".to_string(), "test_secret".to_string()); let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed"); - + // Parse token response let token_response: serde_json::Value = serde_json::from_str(&token_result) .expect("Invalid token response JSON"); - + assert_eq!(token_response["token_type"], "Bearer"); assert_eq!(token_response["expires_in"], 3600); assert!(token_response["access_token"].is_string()); - + let access_token = token_response["access_token"].as_str().unwrap(); - + // Step 3: Verify the JWT token has RSA signature and key ID let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256); @@ -168,7 +173,7 @@ mod disabled_tests { assert!(!header.kid.as_ref().unwrap().is_empty()); } - #[test] + #[test] fn test_token_validation_with_jwks() { let config = sts::Config::from_env(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); @@ -201,11 +206,11 @@ mod disabled_tests { // Get the JWKS let jwks_json = oauth_server.get_jwks(); let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON"); - + // Decode the token header to get the key ID let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header"); let kid = header.kid.as_ref().expect("No key ID in token"); - + // Find the matching key in JWKS let matching_key = jwks["keys"] .as_array() @@ -213,7 +218,7 @@ mod disabled_tests { .iter() .find(|key| key["kid"] == *kid) .expect("Key ID not found in JWKS"); - + assert_eq!(matching_key["kty"], "RSA"); assert_eq!(matching_key["alg"], "RS256"); } @@ -255,17 +260,17 @@ mod disabled_tests { &jsonwebtoken::DecodingKey::from_secret(b"dummy"), // We're not validating, just parsing &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256) ); - + // Since we can't validate with a dummy key, we'll just verify the structure // by decoding the payload manually let parts: Vec<&str> = access_token.split('.').collect(); assert_eq!(parts.len(), 3); // header.payload.signature - + let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD .decode(parts[1]) .expect("Failed to decode payload"); let claims: serde_json::Value = serde_json::from_slice(&payload).expect("Invalid claims JSON"); - + assert!(claims["sub"].is_string()); assert!(claims["iss"].is_string()); assert!(claims["aud"].is_string()); diff --git a/src/migrations.rs b/src/migrations.rs index c7cd6bf..5076a9e 100644 --- a/src/migrations.rs +++ b/src/migrations.rs @@ -43,13 +43,16 @@ impl<'a> MigrationRunner<'a> { // Get current migration version let current_version = self.get_current_version()?; - + println!("Current database version: {}", current_version); // Run pending migrations for migration in MIGRATIONS { if migration.version > current_version { - println!("Running migration {}: {}", migration.version, migration.name); + println!( + "Running migration {}: {}", + migration.version, migration.name + ); self.run_migration(migration)?; } } @@ -70,7 +73,7 @@ impl<'a> MigrationRunner<'a> { fn run_migration(&self, migration: &Migration) -> Result<()> { // Execute the migration SQL self.conn.execute_batch(migration.sql)?; - + // Record the migration as applied self.conn.execute( "INSERT INTO schema_migrations (version, name, applied_at) VALUES (?1, ?2, ?3)", @@ -86,14 +89,14 @@ impl<'a> MigrationRunner<'a> { pub fn rollback_to_version(&self, target_version: i32) -> Result<()> { println!("Rolling back to version {}", target_version); - + // This is a simplified rollback - in practice you'd need down migrations // For now, just remove migration records self.conn.execute( "DELETE FROM schema_migrations WHERE version > ?1", [target_version], )?; - + println!("Rollback completed (Note: This doesn't actually undo schema changes)"); Ok(()) } @@ -101,11 +104,11 @@ impl<'a> MigrationRunner<'a> { pub fn show_migration_status(&self) -> Result<()> { println!("Migration Status:"); println!("================"); - - let mut stmt = self.conn.prepare( - "SELECT version, name, applied_at FROM schema_migrations ORDER BY version" - )?; - + + let mut stmt = self + .conn + .prepare("SELECT version, name, applied_at FROM schema_migrations ORDER BY version")?; + let migrations = stmt.query_map([], |row| { Ok(( row.get::<_, i32>(0)?, @@ -116,14 +119,20 @@ impl<'a> MigrationRunner<'a> { for migration in migrations { let (version, name, applied_at) = migration?; - println!("✅ Migration {}: {} (applied: {})", version, name, applied_at); + println!( + "✅ Migration {}: {} (applied: {})", + version, name, applied_at + ); } // Show pending migrations let current_version = self.get_current_version()?; for migration in MIGRATIONS { if migration.version > current_version { - println!("⏳ Migration {}: {} (pending)", migration.version, migration.name); + println!( + "⏳ Migration {}: {} (pending)", + migration.version, migration.name + ); } } @@ -139,14 +148,17 @@ mod tests { fn test_migration_runner() { let conn = Connection::open_in_memory().unwrap(); let runner = MigrationRunner::new(&conn); - + // Should start with version 0 assert_eq!(runner.get_current_version().unwrap(), 0); - + // Run migrations runner.run_migrations().unwrap(); - + // Should now be at latest version - assert_eq!(runner.get_current_version().unwrap(), MIGRATIONS.len() as i32); + assert_eq!( + runner.get_current_version().unwrap(), + MIGRATIONS.len() as i32 + ); } -}
\ No newline at end of file +} |
