use anyhow::Result; use chrono::{DateTime, Utc}; use rusqlite::{params, Connection}; use serde::{Deserialize, Serialize}; use std::path::Path; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbOAuthClient { pub id: i64, pub client_id: String, pub client_secret_hash: String, pub client_name: String, pub redirect_uris: String, // JSON array pub scopes: String, // Space-separated pub grant_types: String, // Space-separated pub response_types: String, // Space-separated pub created_at: DateTime, pub updated_at: DateTime, pub is_active: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbAuthCode { pub id: i64, pub code: String, pub client_id: String, pub user_id: String, pub redirect_uri: String, pub scope: Option, pub expires_at: DateTime, pub created_at: DateTime, pub is_used: bool, // PKCE fields pub code_challenge: Option, pub code_challenge_method: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbAccessToken { pub id: i64, pub token_id: String, pub client_id: String, pub user_id: String, pub scope: Option, pub expires_at: DateTime, pub created_at: DateTime, pub is_revoked: bool, pub token_hash: String, // For revocation lookup } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbRefreshToken { pub id: i64, pub token_id: String, pub access_token_id: i64, pub client_id: String, pub user_id: String, pub scope: Option, pub expires_at: DateTime, pub created_at: DateTime, pub is_revoked: bool, pub token_hash: String, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbRsaKey { pub id: i64, pub kid: String, pub private_key_pem: String, pub public_key_pem: String, pub created_at: DateTime, pub is_current: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbAuditLog { pub id: i64, pub event_type: String, pub client_id: Option, pub user_id: Option, pub ip_address: Option, pub user_agent: Option, pub details: Option, // JSON pub created_at: DateTime, pub success: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DbRateLimit { pub id: i64, pub identifier: String, // client_id or IP address pub endpoint: String, pub count: i32, pub window_start: DateTime, pub created_at: DateTime, } pub struct Database { conn: Connection, } impl Database { pub fn new>(path: P) -> Result { let conn = Connection::open(path)?; let db = Self { conn }; db.initialize_schema()?; Ok(db) } pub fn new_in_memory() -> Result { let conn = Connection::open_in_memory()?; let db = Self { conn }; db.initialize_schema()?; Ok(db) } fn initialize_schema(&self) -> Result<()> { // OAuth Clients table self.conn.execute( "CREATE TABLE IF NOT EXISTS oauth_clients ( id INTEGER PRIMARY KEY AUTOINCREMENT, client_id TEXT NOT NULL UNIQUE, client_secret_hash TEXT NOT NULL, client_name TEXT NOT NULL, redirect_uris TEXT NOT NULL, -- JSON array scopes TEXT NOT NULL, -- space-separated grant_types TEXT NOT NULL, -- space-separated response_types TEXT NOT NULL, -- space-separated created_at TEXT NOT NULL, updated_at TEXT NOT NULL, is_active BOOLEAN NOT NULL DEFAULT 1 )", [], )?; // Authorization Codes table self.conn.execute( "CREATE TABLE IF NOT EXISTS auth_codes ( id INTEGER PRIMARY KEY AUTOINCREMENT, code TEXT NOT NULL UNIQUE, client_id TEXT NOT NULL, user_id TEXT NOT NULL, redirect_uri TEXT NOT NULL, scope TEXT, expires_at TEXT NOT NULL, created_at TEXT NOT NULL, is_used BOOLEAN NOT NULL DEFAULT 0, code_challenge TEXT, code_challenge_method TEXT, FOREIGN KEY (client_id) REFERENCES oauth_clients (client_id) )", [], )?; // Access Tokens table self.conn.execute( "CREATE TABLE IF NOT EXISTS access_tokens ( id INTEGER PRIMARY KEY AUTOINCREMENT, token_id TEXT NOT NULL UNIQUE, client_id TEXT NOT NULL, user_id TEXT NOT NULL, scope TEXT, expires_at TEXT NOT NULL, created_at TEXT NOT NULL, is_revoked BOOLEAN NOT NULL DEFAULT 0, token_hash TEXT NOT NULL, FOREIGN KEY (client_id) REFERENCES oauth_clients (client_id) )", [], )?; // Refresh Tokens table self.conn.execute( "CREATE TABLE IF NOT EXISTS refresh_tokens ( id INTEGER PRIMARY KEY AUTOINCREMENT, token_id TEXT NOT NULL UNIQUE, access_token_id INTEGER NOT NULL, client_id TEXT NOT NULL, user_id TEXT NOT NULL, scope TEXT, expires_at TEXT NOT NULL, created_at TEXT NOT NULL, is_revoked BOOLEAN NOT NULL DEFAULT 0, token_hash TEXT NOT NULL, FOREIGN KEY (client_id) REFERENCES oauth_clients (client_id), FOREIGN KEY (access_token_id) REFERENCES access_tokens (id) )", [], )?; // RSA Keys table self.conn.execute( "CREATE TABLE IF NOT EXISTS rsa_keys ( id INTEGER PRIMARY KEY AUTOINCREMENT, kid TEXT NOT NULL UNIQUE, private_key_pem TEXT NOT NULL, public_key_pem TEXT NOT NULL, created_at TEXT NOT NULL, is_current BOOLEAN NOT NULL DEFAULT 0 )", [], )?; // Audit Log table self.conn.execute( "CREATE TABLE IF NOT EXISTS audit_logs ( id INTEGER PRIMARY KEY AUTOINCREMENT, event_type TEXT NOT NULL, client_id TEXT, user_id TEXT, ip_address TEXT, user_agent TEXT, details TEXT, -- JSON created_at TEXT NOT NULL, success BOOLEAN NOT NULL )", [], )?; // Rate Limiting table self.conn.execute( "CREATE TABLE IF NOT EXISTS rate_limits ( id INTEGER PRIMARY KEY AUTOINCREMENT, identifier TEXT NOT NULL, -- client_id or IP endpoint TEXT NOT NULL, count INTEGER NOT NULL DEFAULT 1, window_start TEXT NOT NULL, created_at TEXT NOT NULL, UNIQUE (identifier, endpoint, window_start) )", [], )?; // Create indexes for performance self.conn.execute( "CREATE INDEX IF NOT EXISTS idx_auth_codes_client_id ON auth_codes (client_id)", [], )?; self.conn.execute( "CREATE INDEX IF NOT EXISTS idx_auth_codes_expires_at ON auth_codes (expires_at)", [], )?; self.conn.execute( "CREATE INDEX IF NOT EXISTS idx_access_tokens_client_id ON access_tokens (client_id)", [], )?; self.conn.execute( "CREATE INDEX IF NOT EXISTS idx_access_tokens_expires_at ON access_tokens (expires_at)", [], )?; self.conn.execute( "CREATE INDEX IF NOT EXISTS idx_refresh_tokens_client_id ON refresh_tokens (client_id)", [], )?; self.conn.execute( "CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs (created_at)", [], )?; self.conn.execute( "CREATE INDEX IF NOT EXISTS idx_rate_limits_identifier ON rate_limits (identifier, endpoint)", [], )?; Ok(()) } // OAuth Client operations pub fn create_oauth_client(&self, client: &DbOAuthClient) -> Result { let mut stmt = self.conn.prepare( "INSERT INTO oauth_clients (client_id, client_secret_hash, client_name, redirect_uris, scopes, grant_types, response_types, created_at, updated_at, is_active) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)" )?; let id = stmt.insert(params![ client.client_id, client.client_secret_hash, client.client_name, client.redirect_uris, client.scopes, client.grant_types, client.response_types, client.created_at.to_rfc3339(), client.updated_at.to_rfc3339(), client.is_active ])?; Ok(id) } pub fn get_oauth_client(&self, client_id: &str) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, client_id, client_secret_hash, client_name, redirect_uris, scopes, grant_types, response_types, created_at, updated_at, is_active FROM oauth_clients WHERE client_id = ?1 AND is_active = 1" )?; let client = stmt.query_row([client_id], |row| { Ok(DbOAuthClient { id: row.get(0)?, client_id: row.get(1)?, client_secret_hash: row.get(2)?, client_name: row.get(3)?, redirect_uris: row.get(4)?, scopes: row.get(5)?, grant_types: row.get(6)?, response_types: row.get(7)?, created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?) .map_err(|_| rusqlite::Error::InvalidColumnType(8, "created_at".to_string(), rusqlite::types::Type::Text))? .with_timezone(&Utc), updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?) .map_err(|_| rusqlite::Error::InvalidColumnType(9, "updated_at".to_string(), rusqlite::types::Type::Text))? .with_timezone(&Utc), is_active: row.get(10)?, }) }); match client { Ok(client) => Ok(Some(client)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e.into()), } } // Authorization Code operations pub fn create_auth_code(&self, auth_code: &DbAuthCode) -> Result { let mut stmt = self.conn.prepare( "INSERT INTO auth_codes (code, client_id, user_id, redirect_uri, scope, expires_at, created_at, is_used, code_challenge, code_challenge_method) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)" )?; let id = stmt.insert(params![ auth_code.code, auth_code.client_id, auth_code.user_id, auth_code.redirect_uri, auth_code.scope, auth_code.expires_at.to_rfc3339(), auth_code.created_at.to_rfc3339(), auth_code.is_used, auth_code.code_challenge, auth_code.code_challenge_method ])?; Ok(id) } pub fn get_auth_code(&self, code: &str) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, code, client_id, user_id, redirect_uri, scope, expires_at, created_at, is_used, code_challenge, code_challenge_method FROM auth_codes WHERE code = ?1" )?; let auth_code = stmt.query_row([code], |row| { Ok(DbAuthCode { id: row.get(0)?, code: row.get(1)?, client_id: row.get(2)?, user_id: row.get(3)?, redirect_uri: row.get(4)?, scope: row.get(5)?, expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) .map_err(|_| rusqlite::Error::InvalidColumnType(6, "expires_at".to_string(), rusqlite::types::Type::Text))? .with_timezone(&Utc), created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?) .map_err(|_| rusqlite::Error::InvalidColumnType(7, "created_at".to_string(), rusqlite::types::Type::Text))? .with_timezone(&Utc), is_used: row.get(8)?, code_challenge: row.get(9)?, code_challenge_method: row.get(10)?, }) }); match auth_code { Ok(code) => Ok(Some(code)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e.into()), } } pub fn mark_auth_code_used(&self, code: &str) -> Result<()> { self.conn.execute( "UPDATE auth_codes SET is_used = 1 WHERE code = ?1", [code], )?; Ok(()) } // Access Token operations pub fn create_access_token(&self, token: &DbAccessToken) -> Result { let mut stmt = self.conn.prepare( "INSERT INTO access_tokens (token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)" )?; let id = stmt.insert(params![ token.token_id, token.client_id, token.user_id, token.scope, token.expires_at.to_rfc3339(), token.created_at.to_rfc3339(), token.is_revoked, token.token_hash ])?; Ok(id) } pub fn get_access_token(&self, token_hash: &str) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash FROM access_tokens WHERE token_hash = ?1" )?; let token = stmt.query_row([token_hash], |row| { Ok(DbAccessToken { id: row.get(0)?, token_id: row.get(1)?, client_id: row.get(2)?, user_id: row.get(3)?, scope: row.get(4)?, expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?) .map_err(|_| rusqlite::Error::InvalidColumnType(5, "expires_at".to_string(), rusqlite::types::Type::Text))? .with_timezone(&Utc), created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) .map_err(|_| rusqlite::Error::InvalidColumnType(6, "created_at".to_string(), rusqlite::types::Type::Text))? .with_timezone(&Utc), is_revoked: row.get(7)?, token_hash: row.get(8)?, }) }); match token { Ok(token) => Ok(Some(token)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e.into()), } } pub fn revoke_access_token(&self, token_hash: &str) -> Result<()> { self.conn.execute( "UPDATE access_tokens SET is_revoked = 1 WHERE token_hash = ?1", [token_hash], )?; Ok(()) } // RSA Key operations pub fn create_rsa_key(&self, key: &DbRsaKey) -> Result { let mut stmt = self.conn.prepare( "INSERT INTO rsa_keys (kid, private_key_pem, public_key_pem, created_at, is_current) VALUES (?1, ?2, ?3, ?4, ?5)" )?; let id = stmt.insert(params![ key.kid, key.private_key_pem, key.public_key_pem, key.created_at.to_rfc3339(), key.is_current ])?; Ok(id) } pub fn get_current_rsa_key(&self) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, kid, private_key_pem, public_key_pem, created_at, is_current FROM rsa_keys WHERE is_current = 1 ORDER BY created_at DESC LIMIT 1" )?; let key = stmt.query_row([], |row| { Ok(DbRsaKey { id: row.get(0)?, kid: row.get(1)?, private_key_pem: row.get(2)?, public_key_pem: row.get(3)?, created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?) .map_err(|_| rusqlite::Error::InvalidColumnType(4, "created_at".to_string(), rusqlite::types::Type::Text))? .with_timezone(&Utc), is_current: row.get(5)?, }) }); match key { Ok(key) => Ok(Some(key)), Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), Err(e) => Err(e.into()), } } pub fn get_all_rsa_keys(&self) -> Result> { let mut stmt = self.conn.prepare( "SELECT id, kid, private_key_pem, public_key_pem, created_at, is_current FROM rsa_keys ORDER BY created_at DESC" )?; let keys = stmt.query_map([], |row| { Ok(DbRsaKey { id: row.get(0)?, kid: row.get(1)?, private_key_pem: row.get(2)?, public_key_pem: row.get(3)?, created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?) .map_err(|_| rusqlite::Error::InvalidColumnType(4, "created_at".to_string(), rusqlite::types::Type::Text))? .with_timezone(&Utc), is_current: row.get(5)?, }) })?; let mut result = Vec::new(); for key in keys { result.push(key?); } Ok(result) } pub fn set_current_rsa_key(&self, kid: &str) -> Result<()> { // First, unset all current keys self.conn.execute("UPDATE rsa_keys SET is_current = 0", [])?; // Then set the specified key as current self.conn.execute( "UPDATE rsa_keys SET is_current = 1 WHERE kid = ?1", [kid], )?; Ok(()) } // Audit Log operations pub fn create_audit_log(&self, log: &DbAuditLog) -> Result { let mut stmt = self.conn.prepare( "INSERT INTO audit_logs (event_type, client_id, user_id, ip_address, user_agent, details, created_at, success) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)" )?; let id = stmt.insert(params![ log.event_type, log.client_id, log.user_id, log.ip_address, log.user_agent, log.details, log.created_at.to_rfc3339(), log.success ])?; Ok(id) } // Rate Limiting operations pub fn increment_rate_limit(&self, identifier: &str, endpoint: &str, window_minutes: i32) -> Result { let now = Utc::now(); let window_start = now - chrono::Duration::minutes(window_minutes as i64); // Try to increment existing counter in current window let affected = self.conn.execute( "UPDATE rate_limits SET count = count + 1 WHERE identifier = ?1 AND endpoint = ?2 AND window_start >= ?3", params![identifier, endpoint, window_start.to_rfc3339()], )?; if affected == 0 { // No existing record, create new one self.conn.execute( "INSERT OR REPLACE INTO rate_limits (identifier, endpoint, count, window_start, created_at) VALUES (?1, ?2, 1, ?3, ?4)", params![identifier, endpoint, now.to_rfc3339(), now.to_rfc3339()], )?; Ok(1) } else { // Return current count let count: i32 = self.conn.query_row( "SELECT count FROM rate_limits WHERE identifier = ?1 AND endpoint = ?2 AND window_start >= ?3", params![identifier, endpoint, window_start.to_rfc3339()], |row| row.get(0), )?; Ok(count) } } // Cleanup operations pub fn cleanup_expired_codes(&self) -> Result { let now = Utc::now(); let affected = self.conn.execute( "DELETE FROM auth_codes WHERE expires_at < ?1", [now.to_rfc3339()], )?; Ok(affected) } pub fn cleanup_expired_tokens(&self) -> Result { let now = Utc::now(); let affected = self.conn.execute( "DELETE FROM access_tokens WHERE expires_at < ?1 AND is_revoked = 1", [now.to_rfc3339()], )?; Ok(affected) } pub fn cleanup_old_audit_logs(&self, days: i32) -> Result { let cutoff = Utc::now() - chrono::Duration::days(days as i64); let affected = self.conn.execute( "DELETE FROM audit_logs WHERE created_at < ?1", [cutoff.to_rfc3339()], )?; Ok(affected) } } #[cfg(test)] mod tests { use super::*; #[test] fn test_database_creation() { let _db = Database::new_in_memory().expect("Failed to create database"); assert!(true); // If we got here, database was created successfully } #[test] fn test_oauth_client_operations() { let db = Database::new_in_memory().expect("Failed to create database"); let client = DbOAuthClient { id: 0, client_id: "test_client".to_string(), client_secret_hash: "hash123".to_string(), client_name: "Test Client".to_string(), redirect_uris: "[\"http://localhost:3000/callback\"]".to_string(), scopes: "openid profile".to_string(), grant_types: "authorization_code".to_string(), response_types: "code".to_string(), created_at: Utc::now(), updated_at: Utc::now(), is_active: true, }; let id = db.create_oauth_client(&client).expect("Failed to create client"); assert!(id > 0); let retrieved = db.get_oauth_client("test_client").expect("Failed to get client"); assert!(retrieved.is_some()); assert_eq!(retrieved.unwrap().client_name, "Test Client"); } #[test] fn test_auth_code_operations() { let db = Database::new_in_memory().expect("Failed to create database"); // First create a client (required for foreign key constraint) let client = DbOAuthClient { id: 0, client_id: "test_client".to_string(), client_secret_hash: "hash123".to_string(), client_name: "Test Client".to_string(), redirect_uris: "[\"http://localhost:3000/callback\"]".to_string(), scopes: "openid profile".to_string(), grant_types: "authorization_code".to_string(), response_types: "code".to_string(), created_at: Utc::now(), updated_at: Utc::now(), is_active: true, }; db.create_oauth_client(&client).expect("Failed to create client"); let auth_code = DbAuthCode { id: 0, code: "test_code_123".to_string(), client_id: "test_client".to_string(), user_id: "test_user".to_string(), redirect_uri: "http://localhost:3000/callback".to_string(), scope: Some("openid".to_string()), expires_at: Utc::now() + chrono::Duration::minutes(10), created_at: Utc::now(), is_used: false, code_challenge: Some("challenge123".to_string()), code_challenge_method: Some("S256".to_string()), }; let id = db.create_auth_code(&auth_code).expect("Failed to create auth code"); assert!(id > 0); let retrieved = db.get_auth_code("test_code_123").expect("Failed to get auth code"); assert!(retrieved.is_some()); let code = retrieved.unwrap(); assert_eq!(code.client_id, "test_client"); assert_eq!(code.is_used, false); db.mark_auth_code_used("test_code_123").expect("Failed to mark code as used"); let updated = db.get_auth_code("test_code_123").expect("Failed to get auth code"); assert_eq!(updated.unwrap().is_used, true); } }