use anyhow::Result; use chrono::{DateTime, Utc}; use rusqlite::{Connection, params}; use serde::{Deserialize, Serialize}; use std::path::Path; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbOAuthClient { pub id: i64, pub client_id: String, pub client_secret_hash: String, pub client_name: String, pub redirect_uris: String, // JSON array pub scopes: String, // Space-separated pub grant_types: String, // Space-separated pub response_types: String, // Space-separated pub created_at: DateTime, pub updated_at: DateTime, pub is_active: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbAuthCode { pub id: i64, pub code: String, pub client_id: String, pub user_id: String, pub redirect_uri: String, pub scope: Option, pub expires_at: DateTime, pub created_at: DateTime, pub is_used: bool, // PKCE fields pub code_challenge: Option, pub code_challenge_method: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbAccessToken { pub id: i64, pub token_id: String, pub client_id: String, pub user_id: String, pub scope: Option, pub expires_at: DateTime, pub created_at: DateTime, pub is_revoked: bool, pub token_hash: String, // For revocation lookup } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbRefreshToken { pub id: i64, pub token_id: String, pub access_token_id: i64, pub client_id: String, pub user_id: String, pub scope: Option, pub expires_at: DateTime, pub created_at: DateTime, pub is_revoked: bool, pub token_hash: String, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbRsaKey { pub id: i64, pub kid: String, pub private_key_pem: String, pub public_key_pem: String, pub created_at: DateTime, pub is_current: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbAuditLog { pub id: i64, pub event_type: String, pub client_id: Option, pub user_id: Option, pub ip_address: Option, pub user_agent: Option, pub details: Option, // JSON pub created_at: DateTime, pub success: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbRateLimit { pub id: i64, pub identifier: String, // client_id or IP address pub endpoint: String, pub count: i32, pub window_start: DateTime, pub created_at: DateTime, } pub struct Database { conn: Connection, } impl Database { pub fn new>(path: P) -> Result { let conn = Connection::open(path)?; let db = Self { conn }; db.run_migrations()?; Ok(db) } pub fn new_in_memory() -> Result { let conn = Connection::open_in_memory()?; let db = Self { conn }; db.run_migrations()?; Ok(db) } fn run_migrations(&self) -> Result<()> { // Use the migration system instead of duplicated schema let migration_runner = crate::migrations::MigrationRunner::new(&self.conn)?; migration_runner.run_migrations()?; Ok(()) } // OAuth Client operations pub fn create_oauth_client(&self, client: &DbOAuthClient) -> Result { let mut stmt = self.conn.prepare( "INSERT INTO oauth_clients (client_id, client_secret_hash, client_name, redirect_uris, scopes, grant_types, response_types, created_at, updated_at, is_active) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)", )?; let id = stmt.insert(params![ client.client_id, client.client_secret_hash, client.client_name, client.redirect_uris, client.scopes, client.grant_types, client.response_types, client.created_at.to_rfc3339(), client.updated_at.to_rfc3339(), client.is_active ])?; Ok(id) } pub fn get_oauth_client(&self, client_id: &str) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, client_id, client_secret_hash, client_name, redirect_uris, scopes, grant_types, response_types, created_at, updated_at, is_active FROM oauth_clients WHERE client_id = ?1 AND is_active = 1", )?; let client = stmt.query_row([client_id], |row| { Ok(DbOAuthClient { id: row.get(0)?, client_id: row.get(1)?, client_secret_hash: row.get(2)?, client_name: row.get(3)?, redirect_uris: row.get(4)?, scopes: row.get(5)?, grant_types: row.get(6)?, response_types: row.get(7)?, created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?) .map_err(|_| { rusqlite::Error::InvalidColumnType( 8, "created_at".to_string(), rusqlite::types::Type::Text, ) })? .with_timezone(&Utc), updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?) .map_err(|_| { rusqlite::Error::InvalidColumnType( 9, "updated_at".to_string(), rusqlite::types::Type::Text, ) })? .with_timezone(&Utc), is_active: row.get(10)?, }) }); match client { Ok(client) => Ok(Some(client)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e.into()), } } // Authorization Code operations pub fn create_auth_code(&self, auth_code: &DbAuthCode) -> Result { let mut stmt = self.conn.prepare( "INSERT INTO auth_codes (code, client_id, user_id, redirect_uri, scope, expires_at, created_at, is_used, code_challenge, code_challenge_method) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)", )?; let id = stmt.insert(params![ auth_code.code, auth_code.client_id, auth_code.user_id, auth_code.redirect_uri, auth_code.scope, auth_code.expires_at.to_rfc3339(), auth_code.created_at.to_rfc3339(), auth_code.is_used, auth_code.code_challenge, auth_code.code_challenge_method ])?; Ok(id) } pub fn get_auth_code(&self, code: &str) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, code, client_id, user_id, redirect_uri, scope, expires_at, created_at, is_used, code_challenge, code_challenge_method FROM auth_codes WHERE code = ?1", )?; let auth_code = stmt.query_row([code], |row| { Ok(DbAuthCode { id: row.get(0)?, code: row.get(1)?, client_id: row.get(2)?, user_id: row.get(3)?, redirect_uri: row.get(4)?, scope: row.get(5)?, expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) .map_err(|_| { rusqlite::Error::InvalidColumnType( 6, "expires_at".to_string(), rusqlite::types::Type::Text, ) })? .with_timezone(&Utc), created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?) .map_err(|_| { rusqlite::Error::InvalidColumnType( 7, "created_at".to_string(), rusqlite::types::Type::Text, ) })? .with_timezone(&Utc), is_used: row.get(8)?, code_challenge: row.get(9)?, code_challenge_method: row.get(10)?, }) }); match auth_code { Ok(code) => Ok(Some(code)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e.into()), } } pub fn mark_auth_code_used(&self, code: &str) -> Result<()> { self.conn .execute("UPDATE auth_codes SET is_used = 1 WHERE code = ?1", [code])?; Ok(()) } // Access Token operations pub fn create_access_token(&self, token: &DbAccessToken) -> Result { let mut stmt = self.conn.prepare( "INSERT INTO access_tokens (token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", )?; let id = stmt.insert(params![ token.token_id, token.client_id, token.user_id, token.scope, token.expires_at.to_rfc3339(), token.created_at.to_rfc3339(), token.is_revoked, token.token_hash ])?; Ok(id) } pub fn get_access_token(&self, token_hash: &str) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash FROM access_tokens WHERE token_hash = ?1" )?; let token = stmt.query_row([token_hash], |row| { Ok(DbAccessToken { id: row.get(0)?, token_id: row.get(1)?, client_id: row.get(2)?, user_id: row.get(3)?, scope: row.get(4)?, expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?) .map_err(|_| { rusqlite::Error::InvalidColumnType( 5, "expires_at".to_string(), rusqlite::types::Type::Text, ) })? .with_timezone(&Utc), created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) .map_err(|_| { rusqlite::Error::InvalidColumnType( 6, "created_at".to_string(), rusqlite::types::Type::Text, ) })? .with_timezone(&Utc), is_revoked: row.get(7)?, token_hash: row.get(8)?, }) }); match token { Ok(token) => Ok(Some(token)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e.into()), } } pub fn revoke_access_token(&self, token_hash: &str) -> Result<()> { self.conn.execute( "UPDATE access_tokens SET is_revoked = 1 WHERE token_hash = ?1", [token_hash], )?; Ok(()) } // Refresh Token operations pub fn create_refresh_token(&self, token: &DbRefreshToken) -> Result { let mut stmt = self.conn.prepare( "INSERT INTO refresh_tokens (token_id, access_token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", )?; let id = stmt.insert(params![ token.token_id, token.access_token_id, token.client_id, token.user_id, token.scope, token.expires_at.to_rfc3339(), token.created_at.to_rfc3339(), token.is_revoked, token.token_hash ])?; Ok(id) } pub fn get_refresh_token(&self, token_hash: &str) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, token_id, access_token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash FROM refresh_tokens WHERE token_hash = ?1" )?; let token = stmt.query_row([token_hash], |row| { Ok(DbRefreshToken { id: row.get(0)?, token_id: row.get(1)?, access_token_id: row.get(2)?, client_id: row.get(3)?, user_id: row.get(4)?, scope: row.get(5)?, expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) .map_err(|_| { rusqlite::Error::InvalidColumnType( 6, "expires_at".to_string(), rusqlite::types::Type::Text, ) })? .with_timezone(&Utc), created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?) .map_err(|_| { rusqlite::Error::InvalidColumnType( 7, "created_at".to_string(), rusqlite::types::Type::Text, ) })? .with_timezone(&Utc), is_revoked: row.get(8)?, token_hash: row.get(9)?, }) }); match token { Ok(token) => Ok(Some(token)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e.into()), } } pub fn revoke_refresh_token(&self, token_hash: &str) -> Result<()> { self.conn.execute( "UPDATE refresh_tokens SET is_revoked = 1 WHERE token_hash = ?1", [token_hash], )?; Ok(()) } pub fn get_refresh_token_by_access_token(&self, access_token_id: i64) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, token_id, access_token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash FROM refresh_tokens WHERE access_token_id = ?1 AND is_revoked = 0" )?; let token = stmt.query_row([access_token_id], |row| { Ok(DbRefreshToken { id: row.get(0)?, token_id: row.get(1)?, access_token_id: row.get(2)?, client_id: row.get(3)?, user_id: row.get(4)?, scope: row.get(5)?, expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) .map_err(|_| { rusqlite::Error::InvalidColumnType( 6, "expires_at".to_string(), rusqlite::types::Type::Text, ) })? .with_timezone(&Utc), created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?) .map_err(|_| { rusqlite::Error::InvalidColumnType( 7, "created_at".to_string(), rusqlite::types::Type::Text, ) })? .with_timezone(&Utc), is_revoked: row.get(8)?, token_hash: row.get(9)?, }) }); match token { Ok(token) => Ok(Some(token)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e.into()), } } // RSA Key operations pub fn create_rsa_key(&self, key: &DbRsaKey) -> Result { let mut stmt = self.conn.prepare( "INSERT INTO rsa_keys (kid, private_key_pem, public_key_pem, created_at, is_current) VALUES (?1, ?2, ?3, ?4, ?5)", )?; let id = stmt.insert(params![ key.kid, key.private_key_pem, key.public_key_pem, key.created_at.to_rfc3339(), key.is_current ])?; Ok(id) } pub fn get_current_rsa_key(&self) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, kid, private_key_pem, public_key_pem, created_at, is_current FROM rsa_keys WHERE is_current = 1 ORDER BY created_at DESC LIMIT 1", )?; let key = stmt.query_row([], |row| { Ok(DbRsaKey { id: row.get(0)?, kid: row.get(1)?, private_key_pem: row.get(2)?, public_key_pem: row.get(3)?, created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?) .map_err(|_| { rusqlite::Error::InvalidColumnType( 4, "created_at".to_string(), rusqlite::types::Type::Text, ) })? .with_timezone(&Utc), is_current: row.get(5)?, }) }); match key { Ok(key) => Ok(Some(key)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e.into()), } } pub fn get_all_rsa_keys(&self) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, kid, private_key_pem, public_key_pem, created_at, is_current FROM rsa_keys ORDER BY created_at DESC", )?; let keys = stmt.query_map([], |row| { Ok(DbRsaKey { id: row.get(0)?, kid: row.get(1)?, private_key_pem: row.get(2)?, public_key_pem: row.get(3)?, created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?) .map_err(|_| { rusqlite::Error::InvalidColumnType( 4, "created_at".to_string(), rusqlite::types::Type::Text, ) })? .with_timezone(&Utc), is_current: row.get(5)?, }) })?; let mut result = Vec::new(); for key in keys { result.push(key?); } Ok(result) } pub fn set_current_rsa_key(&self, kid: &str) -> Result<()> { // First, unset all current keys self.conn .execute("UPDATE rsa_keys SET is_current = 0", [])?; // Then set the specified key as current self.conn .execute("UPDATE rsa_keys SET is_current = 1 WHERE kid = ?1", [kid])?; Ok(()) } // Audit Log operations pub fn create_audit_log(&self, log: &DbAuditLog) -> Result { let mut stmt = self.conn.prepare( "INSERT INTO audit_logs (event_type, client_id, user_id, ip_address, user_agent, details, created_at, success) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", )?; let id = stmt.insert(params![ log.event_type, log.client_id, log.user_id, log.ip_address, log.user_agent, log.details, log.created_at.to_rfc3339(), log.success ])?; Ok(id) } // Rate Limiting operations pub fn increment_rate_limit( &self, identifier: &str, endpoint: &str, window_minutes: i32, ) -> Result { let now = Utc::now(); let window_start = now - chrono::Duration::minutes(window_minutes as i64); // Try to increment existing counter in current window let affected = self.conn.execute( "UPDATE rate_limits SET count = count + 1 WHERE identifier = ?1 AND endpoint = ?2 AND window_start >= ?3", params![identifier, endpoint, window_start.to_rfc3339()], )?; if affected == 0 { // No existing record, create new one self.conn.execute( "INSERT OR REPLACE INTO rate_limits (identifier, endpoint, count, window_start, created_at) VALUES (?1, ?2, 1, ?3, ?4)", params![identifier, endpoint, now.to_rfc3339(), now.to_rfc3339()], )?; Ok(1) } else { // Return current count let count: i32 = self.conn.query_row( "SELECT count FROM rate_limits WHERE identifier = ?1 AND endpoint = ?2 AND window_start >= ?3", params![identifier, endpoint, window_start.to_rfc3339()], |row| row.get(0), )?; Ok(count) } } // Cleanup operations pub fn cleanup_expired_codes(&self) -> Result { let now = Utc::now(); let affected = self.conn.execute( "DELETE FROM auth_codes WHERE expires_at < ?1", [now.to_rfc3339()], )?; Ok(affected) } pub fn cleanup_expired_tokens(&self) -> Result { let now = Utc::now(); let affected = self.conn.execute( "DELETE FROM access_tokens WHERE expires_at < ?1 AND is_revoked = 1", [now.to_rfc3339()], )?; Ok(affected) } pub fn cleanup_old_audit_logs(&self, days: i32) -> Result { let cutoff = Utc::now() - chrono::Duration::days(days as i64); let affected = self.conn.execute( "DELETE FROM audit_logs WHERE created_at < ?1", [cutoff.to_rfc3339()], )?; Ok(affected) } // Additional methods needed for repository patterns pub fn update_oauth_client(&self, client: &DbOAuthClient) -> Result<()> { self.conn.execute( "UPDATE oauth_clients SET client_secret_hash = ?2, client_name = ?3, redirect_uris = ?4, scopes = ?5, grant_types = ?6, response_types = ?7, updated_at = ?8, is_active = ?9 WHERE client_id = ?1", params![ client.client_id, client.client_secret_hash, client.client_name, client.redirect_uris, client.scopes, client.grant_types, client.response_types, client.updated_at.to_rfc3339(), client.is_active ], )?; Ok(()) } pub fn delete_oauth_client(&self, client_id: &str) -> Result<()> { self.conn.execute( "DELETE FROM oauth_clients WHERE client_id = ?1", [client_id], )?; Ok(()) } pub fn list_oauth_clients(&self) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, client_id, client_secret_hash, client_name, redirect_uris, scopes, grant_types, response_types, created_at, updated_at, is_active FROM oauth_clients ORDER BY created_at DESC", )?; let clients = stmt .query_map([], |row| { Ok(DbOAuthClient { id: row.get(0)?, client_id: row.get(1)?, client_secret_hash: row.get(2)?, client_name: row.get(3)?, redirect_uris: row.get(4)?, scopes: row.get(5)?, grant_types: row.get(6)?, response_types: row.get(7)?, created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?) .map_err(|e| { rusqlite::Error::FromSqlConversionFailure( 8, rusqlite::types::Type::Text, Box::new(e), ) })? .with_timezone(&Utc), updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?) .map_err(|e| { rusqlite::Error::FromSqlConversionFailure( 9, rusqlite::types::Type::Text, Box::new(e), ) })? .with_timezone(&Utc), is_active: row.get(10)?, }) })? .collect::, _>>()?; Ok(clients) } pub fn get_audit_logs(&self, limit: i32) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, event_type, client_id, user_id, ip_address, user_agent, details, created_at, success FROM audit_logs ORDER BY created_at DESC LIMIT ?1" )?; let logs = stmt .query_map([limit], |row| { Ok(DbAuditLog { id: row.get(0)?, event_type: row.get(1)?, client_id: row.get(2)?, user_id: row.get(3)?, ip_address: row.get(4)?, user_agent: row.get(5)?, details: row.get(6)?, created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?) .map_err(|e| { rusqlite::Error::FromSqlConversionFailure( 7, rusqlite::types::Type::Text, Box::new(e), ) })? .with_timezone(&Utc), success: row.get(8)?, }) })? .collect::, _>>()?; Ok(logs) } pub fn cleanup_old_rate_limits(&self) -> Result<()> { let cutoff = Utc::now() - chrono::Duration::hours(24); // Clean up rate limits older than 24 hours self.conn.execute( "DELETE FROM rate_limits WHERE created_at < ?1", [cutoff.to_rfc3339()], )?; Ok(()) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_database_creation() { let _db = Database::new_in_memory().expect("Failed to create database"); assert!(true); // If we got here, database was created successfully } #[test] fn test_oauth_client_operations() { let db = Database::new_in_memory().expect("Failed to create database"); let client = DbOAuthClient { id: 0, client_id: "test_client".to_string(), client_secret_hash: "hash123".to_string(), client_name: "Test Client".to_string(), redirect_uris: "[\"http://localhost:3000/callback\"]".to_string(), scopes: "openid profile".to_string(), grant_types: "authorization_code".to_string(), response_types: "code".to_string(), created_at: Utc::now(), updated_at: Utc::now(), is_active: true, }; let id = db .create_oauth_client(&client) .expect("Failed to create client"); assert!(id > 0); let retrieved = db .get_oauth_client("test_client") .expect("Failed to get client"); assert!(retrieved.is_some()); assert_eq!(retrieved.unwrap().client_name, "Test Client"); } #[test] fn test_auth_code_operations() { let db = Database::new_in_memory().expect("Failed to create database"); // First create a client (required for foreign key constraint) let client = DbOAuthClient { id: 0, client_id: "test_client".to_string(), client_secret_hash: "hash123".to_string(), client_name: "Test Client".to_string(), redirect_uris: "[\"http://localhost:3000/callback\"]".to_string(), scopes: "openid profile".to_string(), grant_types: "authorization_code".to_string(), response_types: "code".to_string(), created_at: Utc::now(), updated_at: Utc::now(), is_active: true, }; db.create_oauth_client(&client) .expect("Failed to create client"); let auth_code = DbAuthCode { id: 0, code: "test_code_123".to_string(), client_id: "test_client".to_string(), user_id: "test_user".to_string(), redirect_uri: "http://localhost:3000/callback".to_string(), scope: Some("openid".to_string()), expires_at: Utc::now() + chrono::Duration::minutes(10), created_at: Utc::now(), is_used: false, code_challenge: Some("challenge123".to_string()), code_challenge_method: Some("S256".to_string()), }; let id = db .create_auth_code(&auth_code) .expect("Failed to create auth code"); assert!(id > 0); let retrieved = db .get_auth_code("test_code_123") .expect("Failed to get auth code"); assert!(retrieved.is_some()); let code = retrieved.unwrap(); assert_eq!(code.client_id, "test_client"); assert_eq!(code.is_used, false); db.mark_auth_code_used("test_code_123") .expect("Failed to mark code as used"); let updated = db .get_auth_code("test_code_123") .expect("Failed to get auth code"); assert_eq!(updated.unwrap().is_used, true); } }