diff options
Diffstat (limited to 'src/database.rs')
| -rw-r--r-- | src/database.rs | 703 |
1 files changed, 703 insertions, 0 deletions
diff --git a/src/database.rs b/src/database.rs new file mode 100644 index 0000000..dc33cf8 --- /dev/null +++ b/src/database.rs @@ -0,0 +1,703 @@ +use anyhow::Result; +use chrono::{DateTime, Utc}; +use rusqlite::{params, Connection}; +use serde::{Deserialize, Serialize}; +use std::path::Path; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DbOAuthClient { + pub id: i64, + pub client_id: String, + pub client_secret_hash: String, + pub client_name: String, + pub redirect_uris: String, // JSON array + pub scopes: String, // Space-separated + pub grant_types: String, // Space-separated + pub response_types: String, // Space-separated + pub created_at: DateTime<Utc>, + pub updated_at: DateTime<Utc>, + pub is_active: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DbAuthCode { + pub id: i64, + pub code: String, + pub client_id: String, + pub user_id: String, + pub redirect_uri: String, + pub scope: Option<String>, + pub expires_at: DateTime<Utc>, + pub created_at: DateTime<Utc>, + pub is_used: bool, + // PKCE fields + pub code_challenge: Option<String>, + pub code_challenge_method: Option<String>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DbAccessToken { + pub id: i64, + pub token_id: String, + pub client_id: String, + pub user_id: String, + pub scope: Option<String>, + pub expires_at: DateTime<Utc>, + pub created_at: DateTime<Utc>, + pub is_revoked: bool, + pub token_hash: String, // For revocation lookup +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DbRefreshToken { + pub id: i64, + pub token_id: String, + pub access_token_id: i64, + pub client_id: String, + pub user_id: String, + pub scope: Option<String>, + pub expires_at: DateTime<Utc>, + pub created_at: DateTime<Utc>, + pub is_revoked: bool, + pub token_hash: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DbRsaKey { + pub id: i64, + pub kid: String, + pub private_key_pem: String, + pub public_key_pem: String, + pub created_at: DateTime<Utc>, + pub is_current: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DbAuditLog { + pub id: i64, + pub event_type: String, + pub client_id: Option<String>, + pub user_id: Option<String>, + pub ip_address: Option<String>, + pub user_agent: Option<String>, + pub details: Option<String>, // JSON + pub created_at: DateTime<Utc>, + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DbRateLimit { + pub id: i64, + pub identifier: String, // client_id or IP address + pub endpoint: String, + pub count: i32, + pub window_start: DateTime<Utc>, + pub created_at: DateTime<Utc>, +} + +pub struct Database { + conn: Connection, +} + +impl Database { + pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> { + let conn = Connection::open(path)?; + let db = Self { conn }; + db.initialize_schema()?; + Ok(db) + } + + pub fn new_in_memory() -> Result<Self> { + let conn = Connection::open_in_memory()?; + let db = Self { conn }; + db.initialize_schema()?; + Ok(db) + } + + fn initialize_schema(&self) -> Result<()> { + // OAuth Clients table + self.conn.execute( + "CREATE TABLE IF NOT EXISTS oauth_clients ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + client_id TEXT NOT NULL UNIQUE, + client_secret_hash TEXT NOT NULL, + client_name TEXT NOT NULL, + redirect_uris TEXT NOT NULL, -- JSON array + scopes TEXT NOT NULL, -- space-separated + grant_types TEXT NOT NULL, -- space-separated + response_types TEXT NOT NULL, -- space-separated + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + is_active BOOLEAN NOT NULL DEFAULT 1 + )", + [], + )?; + + // Authorization Codes table + self.conn.execute( + "CREATE TABLE IF NOT EXISTS auth_codes ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + code TEXT NOT NULL UNIQUE, + client_id TEXT NOT NULL, + user_id TEXT NOT NULL, + redirect_uri TEXT NOT NULL, + scope TEXT, + expires_at TEXT NOT NULL, + created_at TEXT NOT NULL, + is_used BOOLEAN NOT NULL DEFAULT 0, + code_challenge TEXT, + code_challenge_method TEXT, + FOREIGN KEY (client_id) REFERENCES oauth_clients (client_id) + )", + [], + )?; + + // Access Tokens table + self.conn.execute( + "CREATE TABLE IF NOT EXISTS access_tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token_id TEXT NOT NULL UNIQUE, + client_id TEXT NOT NULL, + user_id TEXT NOT NULL, + scope TEXT, + expires_at TEXT NOT NULL, + created_at TEXT NOT NULL, + is_revoked BOOLEAN NOT NULL DEFAULT 0, + token_hash TEXT NOT NULL, + FOREIGN KEY (client_id) REFERENCES oauth_clients (client_id) + )", + [], + )?; + + // Refresh Tokens table + self.conn.execute( + "CREATE TABLE IF NOT EXISTS refresh_tokens ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + token_id TEXT NOT NULL UNIQUE, + access_token_id INTEGER NOT NULL, + client_id TEXT NOT NULL, + user_id TEXT NOT NULL, + scope TEXT, + expires_at TEXT NOT NULL, + created_at TEXT NOT NULL, + is_revoked BOOLEAN NOT NULL DEFAULT 0, + token_hash TEXT NOT NULL, + FOREIGN KEY (client_id) REFERENCES oauth_clients (client_id), + FOREIGN KEY (access_token_id) REFERENCES access_tokens (id) + )", + [], + )?; + + // RSA Keys table + self.conn.execute( + "CREATE TABLE IF NOT EXISTS rsa_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + kid TEXT NOT NULL UNIQUE, + private_key_pem TEXT NOT NULL, + public_key_pem TEXT NOT NULL, + created_at TEXT NOT NULL, + is_current BOOLEAN NOT NULL DEFAULT 0 + )", + [], + )?; + + // Audit Log table + self.conn.execute( + "CREATE TABLE IF NOT EXISTS audit_logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + event_type TEXT NOT NULL, + client_id TEXT, + user_id TEXT, + ip_address TEXT, + user_agent TEXT, + details TEXT, -- JSON + created_at TEXT NOT NULL, + success BOOLEAN NOT NULL + )", + [], + )?; + + // Rate Limiting table + self.conn.execute( + "CREATE TABLE IF NOT EXISTS rate_limits ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + identifier TEXT NOT NULL, -- client_id or IP + endpoint TEXT NOT NULL, + count INTEGER NOT NULL DEFAULT 1, + window_start TEXT NOT NULL, + created_at TEXT NOT NULL, + UNIQUE (identifier, endpoint, window_start) + )", + [], + )?; + + // Create indexes for performance + self.conn.execute( + "CREATE INDEX IF NOT EXISTS idx_auth_codes_client_id ON auth_codes (client_id)", + [], + )?; + self.conn.execute( + "CREATE INDEX IF NOT EXISTS idx_auth_codes_expires_at ON auth_codes (expires_at)", + [], + )?; + self.conn.execute( + "CREATE INDEX IF NOT EXISTS idx_access_tokens_client_id ON access_tokens (client_id)", + [], + )?; + self.conn.execute( + "CREATE INDEX IF NOT EXISTS idx_access_tokens_expires_at ON access_tokens (expires_at)", + [], + )?; + self.conn.execute( + "CREATE INDEX IF NOT EXISTS idx_refresh_tokens_client_id ON refresh_tokens (client_id)", + [], + )?; + self.conn.execute( + "CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs (created_at)", + [], + )?; + self.conn.execute( + "CREATE INDEX IF NOT EXISTS idx_rate_limits_identifier ON rate_limits (identifier, endpoint)", + [], + )?; + + Ok(()) + } + + // OAuth Client operations + pub fn create_oauth_client(&self, client: &DbOAuthClient) -> Result<i64> { + let mut stmt = self.conn.prepare( + "INSERT INTO oauth_clients + (client_id, client_secret_hash, client_name, redirect_uris, scopes, + grant_types, response_types, created_at, updated_at, is_active) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)" + )?; + + let id = stmt.insert(params![ + client.client_id, + client.client_secret_hash, + client.client_name, + client.redirect_uris, + client.scopes, + client.grant_types, + client.response_types, + client.created_at.to_rfc3339(), + client.updated_at.to_rfc3339(), + client.is_active + ])?; + + Ok(id) + } + + pub fn get_oauth_client(&self, client_id: &str) -> Result<Option<DbOAuthClient>> { + let mut stmt = self.conn.prepare( + "SELECT id, client_id, client_secret_hash, client_name, redirect_uris, + scopes, grant_types, response_types, created_at, updated_at, is_active + FROM oauth_clients WHERE client_id = ?1 AND is_active = 1" + )?; + + let client = stmt.query_row([client_id], |row| { + Ok(DbOAuthClient { + id: row.get(0)?, + client_id: row.get(1)?, + client_secret_hash: row.get(2)?, + client_name: row.get(3)?, + redirect_uris: row.get(4)?, + scopes: row.get(5)?, + grant_types: row.get(6)?, + response_types: row.get(7)?, + created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?) + .map_err(|_| rusqlite::Error::InvalidColumnType(8, "created_at".to_string(), rusqlite::types::Type::Text))? + .with_timezone(&Utc), + updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?) + .map_err(|_| rusqlite::Error::InvalidColumnType(9, "updated_at".to_string(), rusqlite::types::Type::Text))? + .with_timezone(&Utc), + is_active: row.get(10)?, + }) + }); + + match client { + Ok(client) => Ok(Some(client)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(e.into()), + } + } + + // Authorization Code operations + pub fn create_auth_code(&self, auth_code: &DbAuthCode) -> Result<i64> { + let mut stmt = self.conn.prepare( + "INSERT INTO auth_codes + (code, client_id, user_id, redirect_uri, scope, expires_at, created_at, + is_used, code_challenge, code_challenge_method) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)" + )?; + + let id = stmt.insert(params![ + auth_code.code, + auth_code.client_id, + auth_code.user_id, + auth_code.redirect_uri, + auth_code.scope, + auth_code.expires_at.to_rfc3339(), + auth_code.created_at.to_rfc3339(), + auth_code.is_used, + auth_code.code_challenge, + auth_code.code_challenge_method + ])?; + + Ok(id) + } + + pub fn get_auth_code(&self, code: &str) -> Result<Option<DbAuthCode>> { + let mut stmt = self.conn.prepare( + "SELECT id, code, client_id, user_id, redirect_uri, scope, expires_at, + created_at, is_used, code_challenge, code_challenge_method + FROM auth_codes WHERE code = ?1" + )?; + + let auth_code = stmt.query_row([code], |row| { + Ok(DbAuthCode { + id: row.get(0)?, + code: row.get(1)?, + client_id: row.get(2)?, + user_id: row.get(3)?, + redirect_uri: row.get(4)?, + scope: row.get(5)?, + expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) + .map_err(|_| rusqlite::Error::InvalidColumnType(6, "expires_at".to_string(), rusqlite::types::Type::Text))? + .with_timezone(&Utc), + created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?) + .map_err(|_| rusqlite::Error::InvalidColumnType(7, "created_at".to_string(), rusqlite::types::Type::Text))? + .with_timezone(&Utc), + is_used: row.get(8)?, + code_challenge: row.get(9)?, + code_challenge_method: row.get(10)?, + }) + }); + + match auth_code { + Ok(code) => Ok(Some(code)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(e.into()), + } + } + + pub fn mark_auth_code_used(&self, code: &str) -> Result<()> { + self.conn.execute( + "UPDATE auth_codes SET is_used = 1 WHERE code = ?1", + [code], + )?; + Ok(()) + } + + // Access Token operations + pub fn create_access_token(&self, token: &DbAccessToken) -> Result<i64> { + let mut stmt = self.conn.prepare( + "INSERT INTO access_tokens + (token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)" + )?; + + let id = stmt.insert(params![ + token.token_id, + token.client_id, + token.user_id, + token.scope, + token.expires_at.to_rfc3339(), + token.created_at.to_rfc3339(), + token.is_revoked, + token.token_hash + ])?; + + Ok(id) + } + + pub fn get_access_token(&self, token_hash: &str) -> Result<Option<DbAccessToken>> { + let mut stmt = self.conn.prepare( + "SELECT id, token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash + FROM access_tokens WHERE token_hash = ?1" + )?; + + let token = stmt.query_row([token_hash], |row| { + Ok(DbAccessToken { + id: row.get(0)?, + token_id: row.get(1)?, + client_id: row.get(2)?, + user_id: row.get(3)?, + scope: row.get(4)?, + expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?) + .map_err(|_| rusqlite::Error::InvalidColumnType(5, "expires_at".to_string(), rusqlite::types::Type::Text))? + .with_timezone(&Utc), + created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) + .map_err(|_| rusqlite::Error::InvalidColumnType(6, "created_at".to_string(), rusqlite::types::Type::Text))? + .with_timezone(&Utc), + is_revoked: row.get(7)?, + token_hash: row.get(8)?, + }) + }); + + match token { + Ok(token) => Ok(Some(token)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(e.into()), + } + } + + pub fn revoke_access_token(&self, token_hash: &str) -> Result<()> { + self.conn.execute( + "UPDATE access_tokens SET is_revoked = 1 WHERE token_hash = ?1", + [token_hash], + )?; + Ok(()) + } + + // RSA Key operations + pub fn create_rsa_key(&self, key: &DbRsaKey) -> Result<i64> { + let mut stmt = self.conn.prepare( + "INSERT INTO rsa_keys (kid, private_key_pem, public_key_pem, created_at, is_current) + VALUES (?1, ?2, ?3, ?4, ?5)" + )?; + + let id = stmt.insert(params![ + key.kid, + key.private_key_pem, + key.public_key_pem, + key.created_at.to_rfc3339(), + key.is_current + ])?; + + Ok(id) + } + + pub fn get_current_rsa_key(&self) -> Result<Option<DbRsaKey>> { + let mut stmt = self.conn.prepare( + "SELECT id, kid, private_key_pem, public_key_pem, created_at, is_current + FROM rsa_keys WHERE is_current = 1 ORDER BY created_at DESC LIMIT 1" + )?; + + let key = stmt.query_row([], |row| { + Ok(DbRsaKey { + id: row.get(0)?, + kid: row.get(1)?, + private_key_pem: row.get(2)?, + public_key_pem: row.get(3)?, + created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?) + .map_err(|_| rusqlite::Error::InvalidColumnType(4, "created_at".to_string(), rusqlite::types::Type::Text))? + .with_timezone(&Utc), + is_current: row.get(5)?, + }) + }); + + match key { + Ok(key) => Ok(Some(key)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(e.into()), + } + } + + pub fn get_all_rsa_keys(&self) -> Result<Vec<DbRsaKey>> { + let mut stmt = self.conn.prepare( + "SELECT id, kid, private_key_pem, public_key_pem, created_at, is_current + FROM rsa_keys ORDER BY created_at DESC" + )?; + + let keys = stmt.query_map([], |row| { + Ok(DbRsaKey { + id: row.get(0)?, + kid: row.get(1)?, + private_key_pem: row.get(2)?, + public_key_pem: row.get(3)?, + created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?) + .map_err(|_| rusqlite::Error::InvalidColumnType(4, "created_at".to_string(), rusqlite::types::Type::Text))? + .with_timezone(&Utc), + is_current: row.get(5)?, + }) + })?; + + let mut result = Vec::new(); + for key in keys { + result.push(key?); + } + Ok(result) + } + + pub fn set_current_rsa_key(&self, kid: &str) -> Result<()> { + // First, unset all current keys + self.conn.execute("UPDATE rsa_keys SET is_current = 0", [])?; + + // Then set the specified key as current + self.conn.execute( + "UPDATE rsa_keys SET is_current = 1 WHERE kid = ?1", + [kid], + )?; + + Ok(()) + } + + // Audit Log operations + pub fn create_audit_log(&self, log: &DbAuditLog) -> Result<i64> { + let mut stmt = self.conn.prepare( + "INSERT INTO audit_logs + (event_type, client_id, user_id, ip_address, user_agent, details, created_at, success) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)" + )?; + + let id = stmt.insert(params![ + log.event_type, + log.client_id, + log.user_id, + log.ip_address, + log.user_agent, + log.details, + log.created_at.to_rfc3339(), + log.success + ])?; + + Ok(id) + } + + // Rate Limiting operations + pub fn increment_rate_limit(&self, identifier: &str, endpoint: &str, window_minutes: i32) -> Result<i32> { + let now = Utc::now(); + let window_start = now - chrono::Duration::minutes(window_minutes as i64); + + // Try to increment existing counter in current window + let affected = self.conn.execute( + "UPDATE rate_limits SET count = count + 1 + WHERE identifier = ?1 AND endpoint = ?2 AND window_start >= ?3", + params![identifier, endpoint, window_start.to_rfc3339()], + )?; + + if affected == 0 { + // No existing record, create new one + self.conn.execute( + "INSERT OR REPLACE INTO rate_limits (identifier, endpoint, count, window_start, created_at) + VALUES (?1, ?2, 1, ?3, ?4)", + params![identifier, endpoint, now.to_rfc3339(), now.to_rfc3339()], + )?; + Ok(1) + } else { + // Return current count + let count: i32 = self.conn.query_row( + "SELECT count FROM rate_limits + WHERE identifier = ?1 AND endpoint = ?2 AND window_start >= ?3", + params![identifier, endpoint, window_start.to_rfc3339()], + |row| row.get(0), + )?; + Ok(count) + } + } + + // Cleanup operations + pub fn cleanup_expired_codes(&self) -> Result<usize> { + let now = Utc::now(); + let affected = self.conn.execute( + "DELETE FROM auth_codes WHERE expires_at < ?1", + [now.to_rfc3339()], + )?; + Ok(affected) + } + + pub fn cleanup_expired_tokens(&self) -> Result<usize> { + let now = Utc::now(); + let affected = self.conn.execute( + "DELETE FROM access_tokens WHERE expires_at < ?1 AND is_revoked = 1", + [now.to_rfc3339()], + )?; + Ok(affected) + } + + pub fn cleanup_old_audit_logs(&self, days: i32) -> Result<usize> { + let cutoff = Utc::now() - chrono::Duration::days(days as i64); + let affected = self.conn.execute( + "DELETE FROM audit_logs WHERE created_at < ?1", + [cutoff.to_rfc3339()], + )?; + Ok(affected) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_database_creation() { + let _db = Database::new_in_memory().expect("Failed to create database"); + assert!(true); // If we got here, database was created successfully + } + + #[test] + fn test_oauth_client_operations() { + let db = Database::new_in_memory().expect("Failed to create database"); + + let client = DbOAuthClient { + id: 0, + client_id: "test_client".to_string(), + client_secret_hash: "hash123".to_string(), + client_name: "Test Client".to_string(), + redirect_uris: "[\"http://localhost:3000/callback\"]".to_string(), + scopes: "openid profile".to_string(), + grant_types: "authorization_code".to_string(), + response_types: "code".to_string(), + created_at: Utc::now(), + updated_at: Utc::now(), + is_active: true, + }; + + let id = db.create_oauth_client(&client).expect("Failed to create client"); + assert!(id > 0); + + let retrieved = db.get_oauth_client("test_client").expect("Failed to get client"); + assert!(retrieved.is_some()); + assert_eq!(retrieved.unwrap().client_name, "Test Client"); + } + + #[test] + fn test_auth_code_operations() { + let db = Database::new_in_memory().expect("Failed to create database"); + + // First create a client (required for foreign key constraint) + let client = DbOAuthClient { + id: 0, + client_id: "test_client".to_string(), + client_secret_hash: "hash123".to_string(), + client_name: "Test Client".to_string(), + redirect_uris: "[\"http://localhost:3000/callback\"]".to_string(), + scopes: "openid profile".to_string(), + grant_types: "authorization_code".to_string(), + response_types: "code".to_string(), + created_at: Utc::now(), + updated_at: Utc::now(), + is_active: true, + }; + db.create_oauth_client(&client).expect("Failed to create client"); + + let auth_code = DbAuthCode { + id: 0, + code: "test_code_123".to_string(), + client_id: "test_client".to_string(), + user_id: "test_user".to_string(), + redirect_uri: "http://localhost:3000/callback".to_string(), + scope: Some("openid".to_string()), + expires_at: Utc::now() + chrono::Duration::minutes(10), + created_at: Utc::now(), + is_used: false, + code_challenge: Some("challenge123".to_string()), + code_challenge_method: Some("S256".to_string()), + }; + + let id = db.create_auth_code(&auth_code).expect("Failed to create auth code"); + assert!(id > 0); + + let retrieved = db.get_auth_code("test_code_123").expect("Failed to get auth code"); + assert!(retrieved.is_some()); + let code = retrieved.unwrap(); + assert_eq!(code.client_id, "test_client"); + assert_eq!(code.is_used, false); + + db.mark_auth_code_used("test_code_123").expect("Failed to mark code as used"); + let updated = db.get_auth_code("test_code_123").expect("Failed to get auth code"); + assert_eq!(updated.unwrap().is_used, true); + } +}
\ No newline at end of file |
