summaryrefslogtreecommitdiff
path: root/src/database.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/database.rs')
-rw-r--r--src/database.rs703
1 files changed, 703 insertions, 0 deletions
diff --git a/src/database.rs b/src/database.rs
new file mode 100644
index 0000000..dc33cf8
--- /dev/null
+++ b/src/database.rs
@@ -0,0 +1,703 @@
+use anyhow::Result;
+use chrono::{DateTime, Utc};
+use rusqlite::{params, Connection};
+use serde::{Deserialize, Serialize};
+use std::path::Path;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DbOAuthClient {
+ pub id: i64,
+ pub client_id: String,
+ pub client_secret_hash: String,
+ pub client_name: String,
+ pub redirect_uris: String, // JSON array
+ pub scopes: String, // Space-separated
+ pub grant_types: String, // Space-separated
+ pub response_types: String, // Space-separated
+ pub created_at: DateTime<Utc>,
+ pub updated_at: DateTime<Utc>,
+ pub is_active: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DbAuthCode {
+ pub id: i64,
+ pub code: String,
+ pub client_id: String,
+ pub user_id: String,
+ pub redirect_uri: String,
+ pub scope: Option<String>,
+ pub expires_at: DateTime<Utc>,
+ pub created_at: DateTime<Utc>,
+ pub is_used: bool,
+ // PKCE fields
+ pub code_challenge: Option<String>,
+ pub code_challenge_method: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DbAccessToken {
+ pub id: i64,
+ pub token_id: String,
+ pub client_id: String,
+ pub user_id: String,
+ pub scope: Option<String>,
+ pub expires_at: DateTime<Utc>,
+ pub created_at: DateTime<Utc>,
+ pub is_revoked: bool,
+ pub token_hash: String, // For revocation lookup
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DbRefreshToken {
+ pub id: i64,
+ pub token_id: String,
+ pub access_token_id: i64,
+ pub client_id: String,
+ pub user_id: String,
+ pub scope: Option<String>,
+ pub expires_at: DateTime<Utc>,
+ pub created_at: DateTime<Utc>,
+ pub is_revoked: bool,
+ pub token_hash: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DbRsaKey {
+ pub id: i64,
+ pub kid: String,
+ pub private_key_pem: String,
+ pub public_key_pem: String,
+ pub created_at: DateTime<Utc>,
+ pub is_current: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DbAuditLog {
+ pub id: i64,
+ pub event_type: String,
+ pub client_id: Option<String>,
+ pub user_id: Option<String>,
+ pub ip_address: Option<String>,
+ pub user_agent: Option<String>,
+ pub details: Option<String>, // JSON
+ pub created_at: DateTime<Utc>,
+ pub success: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DbRateLimit {
+ pub id: i64,
+ pub identifier: String, // client_id or IP address
+ pub endpoint: String,
+ pub count: i32,
+ pub window_start: DateTime<Utc>,
+ pub created_at: DateTime<Utc>,
+}
+
+pub struct Database {
+ conn: Connection,
+}
+
+impl Database {
+ pub fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
+ let conn = Connection::open(path)?;
+ let db = Self { conn };
+ db.initialize_schema()?;
+ Ok(db)
+ }
+
+ pub fn new_in_memory() -> Result<Self> {
+ let conn = Connection::open_in_memory()?;
+ let db = Self { conn };
+ db.initialize_schema()?;
+ Ok(db)
+ }
+
+ fn initialize_schema(&self) -> Result<()> {
+ // OAuth Clients table
+ self.conn.execute(
+ "CREATE TABLE IF NOT EXISTS oauth_clients (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ client_id TEXT NOT NULL UNIQUE,
+ client_secret_hash TEXT NOT NULL,
+ client_name TEXT NOT NULL,
+ redirect_uris TEXT NOT NULL, -- JSON array
+ scopes TEXT NOT NULL, -- space-separated
+ grant_types TEXT NOT NULL, -- space-separated
+ response_types TEXT NOT NULL, -- space-separated
+ created_at TEXT NOT NULL,
+ updated_at TEXT NOT NULL,
+ is_active BOOLEAN NOT NULL DEFAULT 1
+ )",
+ [],
+ )?;
+
+ // Authorization Codes table
+ self.conn.execute(
+ "CREATE TABLE IF NOT EXISTS auth_codes (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ code TEXT NOT NULL UNIQUE,
+ client_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ redirect_uri TEXT NOT NULL,
+ scope TEXT,
+ expires_at TEXT NOT NULL,
+ created_at TEXT NOT NULL,
+ is_used BOOLEAN NOT NULL DEFAULT 0,
+ code_challenge TEXT,
+ code_challenge_method TEXT,
+ FOREIGN KEY (client_id) REFERENCES oauth_clients (client_id)
+ )",
+ [],
+ )?;
+
+ // Access Tokens table
+ self.conn.execute(
+ "CREATE TABLE IF NOT EXISTS access_tokens (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ token_id TEXT NOT NULL UNIQUE,
+ client_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ scope TEXT,
+ expires_at TEXT NOT NULL,
+ created_at TEXT NOT NULL,
+ is_revoked BOOLEAN NOT NULL DEFAULT 0,
+ token_hash TEXT NOT NULL,
+ FOREIGN KEY (client_id) REFERENCES oauth_clients (client_id)
+ )",
+ [],
+ )?;
+
+ // Refresh Tokens table
+ self.conn.execute(
+ "CREATE TABLE IF NOT EXISTS refresh_tokens (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ token_id TEXT NOT NULL UNIQUE,
+ access_token_id INTEGER NOT NULL,
+ client_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ scope TEXT,
+ expires_at TEXT NOT NULL,
+ created_at TEXT NOT NULL,
+ is_revoked BOOLEAN NOT NULL DEFAULT 0,
+ token_hash TEXT NOT NULL,
+ FOREIGN KEY (client_id) REFERENCES oauth_clients (client_id),
+ FOREIGN KEY (access_token_id) REFERENCES access_tokens (id)
+ )",
+ [],
+ )?;
+
+ // RSA Keys table
+ self.conn.execute(
+ "CREATE TABLE IF NOT EXISTS rsa_keys (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ kid TEXT NOT NULL UNIQUE,
+ private_key_pem TEXT NOT NULL,
+ public_key_pem TEXT NOT NULL,
+ created_at TEXT NOT NULL,
+ is_current BOOLEAN NOT NULL DEFAULT 0
+ )",
+ [],
+ )?;
+
+ // Audit Log table
+ self.conn.execute(
+ "CREATE TABLE IF NOT EXISTS audit_logs (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ event_type TEXT NOT NULL,
+ client_id TEXT,
+ user_id TEXT,
+ ip_address TEXT,
+ user_agent TEXT,
+ details TEXT, -- JSON
+ created_at TEXT NOT NULL,
+ success BOOLEAN NOT NULL
+ )",
+ [],
+ )?;
+
+ // Rate Limiting table
+ self.conn.execute(
+ "CREATE TABLE IF NOT EXISTS rate_limits (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ identifier TEXT NOT NULL, -- client_id or IP
+ endpoint TEXT NOT NULL,
+ count INTEGER NOT NULL DEFAULT 1,
+ window_start TEXT NOT NULL,
+ created_at TEXT NOT NULL,
+ UNIQUE (identifier, endpoint, window_start)
+ )",
+ [],
+ )?;
+
+ // Create indexes for performance
+ self.conn.execute(
+ "CREATE INDEX IF NOT EXISTS idx_auth_codes_client_id ON auth_codes (client_id)",
+ [],
+ )?;
+ self.conn.execute(
+ "CREATE INDEX IF NOT EXISTS idx_auth_codes_expires_at ON auth_codes (expires_at)",
+ [],
+ )?;
+ self.conn.execute(
+ "CREATE INDEX IF NOT EXISTS idx_access_tokens_client_id ON access_tokens (client_id)",
+ [],
+ )?;
+ self.conn.execute(
+ "CREATE INDEX IF NOT EXISTS idx_access_tokens_expires_at ON access_tokens (expires_at)",
+ [],
+ )?;
+ self.conn.execute(
+ "CREATE INDEX IF NOT EXISTS idx_refresh_tokens_client_id ON refresh_tokens (client_id)",
+ [],
+ )?;
+ self.conn.execute(
+ "CREATE INDEX IF NOT EXISTS idx_audit_logs_created_at ON audit_logs (created_at)",
+ [],
+ )?;
+ self.conn.execute(
+ "CREATE INDEX IF NOT EXISTS idx_rate_limits_identifier ON rate_limits (identifier, endpoint)",
+ [],
+ )?;
+
+ Ok(())
+ }
+
+ // OAuth Client operations
+ pub fn create_oauth_client(&self, client: &DbOAuthClient) -> Result<i64> {
+ let mut stmt = self.conn.prepare(
+ "INSERT INTO oauth_clients
+ (client_id, client_secret_hash, client_name, redirect_uris, scopes,
+ grant_types, response_types, created_at, updated_at, is_active)
+ VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)"
+ )?;
+
+ let id = stmt.insert(params![
+ client.client_id,
+ client.client_secret_hash,
+ client.client_name,
+ client.redirect_uris,
+ client.scopes,
+ client.grant_types,
+ client.response_types,
+ client.created_at.to_rfc3339(),
+ client.updated_at.to_rfc3339(),
+ client.is_active
+ ])?;
+
+ Ok(id)
+ }
+
+ pub fn get_oauth_client(&self, client_id: &str) -> Result<Option<DbOAuthClient>> {
+ let mut stmt = self.conn.prepare(
+ "SELECT id, client_id, client_secret_hash, client_name, redirect_uris,
+ scopes, grant_types, response_types, created_at, updated_at, is_active
+ FROM oauth_clients WHERE client_id = ?1 AND is_active = 1"
+ )?;
+
+ let client = stmt.query_row([client_id], |row| {
+ Ok(DbOAuthClient {
+ id: row.get(0)?,
+ client_id: row.get(1)?,
+ client_secret_hash: row.get(2)?,
+ client_name: row.get(3)?,
+ redirect_uris: row.get(4)?,
+ scopes: row.get(5)?,
+ grant_types: row.get(6)?,
+ response_types: row.get(7)?,
+ created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?)
+ .map_err(|_| rusqlite::Error::InvalidColumnType(8, "created_at".to_string(), rusqlite::types::Type::Text))?
+ .with_timezone(&Utc),
+ updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?)
+ .map_err(|_| rusqlite::Error::InvalidColumnType(9, "updated_at".to_string(), rusqlite::types::Type::Text))?
+ .with_timezone(&Utc),
+ is_active: row.get(10)?,
+ })
+ });
+
+ match client {
+ Ok(client) => Ok(Some(client)),
+ Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
+ Err(e) => Err(e.into()),
+ }
+ }
+
+ // Authorization Code operations
+ pub fn create_auth_code(&self, auth_code: &DbAuthCode) -> Result<i64> {
+ let mut stmt = self.conn.prepare(
+ "INSERT INTO auth_codes
+ (code, client_id, user_id, redirect_uri, scope, expires_at, created_at,
+ is_used, code_challenge, code_challenge_method)
+ VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)"
+ )?;
+
+ let id = stmt.insert(params![
+ auth_code.code,
+ auth_code.client_id,
+ auth_code.user_id,
+ auth_code.redirect_uri,
+ auth_code.scope,
+ auth_code.expires_at.to_rfc3339(),
+ auth_code.created_at.to_rfc3339(),
+ auth_code.is_used,
+ auth_code.code_challenge,
+ auth_code.code_challenge_method
+ ])?;
+
+ Ok(id)
+ }
+
+ pub fn get_auth_code(&self, code: &str) -> Result<Option<DbAuthCode>> {
+ let mut stmt = self.conn.prepare(
+ "SELECT id, code, client_id, user_id, redirect_uri, scope, expires_at,
+ created_at, is_used, code_challenge, code_challenge_method
+ FROM auth_codes WHERE code = ?1"
+ )?;
+
+ let auth_code = stmt.query_row([code], |row| {
+ Ok(DbAuthCode {
+ id: row.get(0)?,
+ code: row.get(1)?,
+ client_id: row.get(2)?,
+ user_id: row.get(3)?,
+ redirect_uri: row.get(4)?,
+ scope: row.get(5)?,
+ expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?)
+ .map_err(|_| rusqlite::Error::InvalidColumnType(6, "expires_at".to_string(), rusqlite::types::Type::Text))?
+ .with_timezone(&Utc),
+ created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?)
+ .map_err(|_| rusqlite::Error::InvalidColumnType(7, "created_at".to_string(), rusqlite::types::Type::Text))?
+ .with_timezone(&Utc),
+ is_used: row.get(8)?,
+ code_challenge: row.get(9)?,
+ code_challenge_method: row.get(10)?,
+ })
+ });
+
+ match auth_code {
+ Ok(code) => Ok(Some(code)),
+ Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
+ Err(e) => Err(e.into()),
+ }
+ }
+
+ pub fn mark_auth_code_used(&self, code: &str) -> Result<()> {
+ self.conn.execute(
+ "UPDATE auth_codes SET is_used = 1 WHERE code = ?1",
+ [code],
+ )?;
+ Ok(())
+ }
+
+ // Access Token operations
+ pub fn create_access_token(&self, token: &DbAccessToken) -> Result<i64> {
+ let mut stmt = self.conn.prepare(
+ "INSERT INTO access_tokens
+ (token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash)
+ VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)"
+ )?;
+
+ let id = stmt.insert(params![
+ token.token_id,
+ token.client_id,
+ token.user_id,
+ token.scope,
+ token.expires_at.to_rfc3339(),
+ token.created_at.to_rfc3339(),
+ token.is_revoked,
+ token.token_hash
+ ])?;
+
+ Ok(id)
+ }
+
+ pub fn get_access_token(&self, token_hash: &str) -> Result<Option<DbAccessToken>> {
+ let mut stmt = self.conn.prepare(
+ "SELECT id, token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash
+ FROM access_tokens WHERE token_hash = ?1"
+ )?;
+
+ let token = stmt.query_row([token_hash], |row| {
+ Ok(DbAccessToken {
+ id: row.get(0)?,
+ token_id: row.get(1)?,
+ client_id: row.get(2)?,
+ user_id: row.get(3)?,
+ scope: row.get(4)?,
+ expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?)
+ .map_err(|_| rusqlite::Error::InvalidColumnType(5, "expires_at".to_string(), rusqlite::types::Type::Text))?
+ .with_timezone(&Utc),
+ created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?)
+ .map_err(|_| rusqlite::Error::InvalidColumnType(6, "created_at".to_string(), rusqlite::types::Type::Text))?
+ .with_timezone(&Utc),
+ is_revoked: row.get(7)?,
+ token_hash: row.get(8)?,
+ })
+ });
+
+ match token {
+ Ok(token) => Ok(Some(token)),
+ Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
+ Err(e) => Err(e.into()),
+ }
+ }
+
+ pub fn revoke_access_token(&self, token_hash: &str) -> Result<()> {
+ self.conn.execute(
+ "UPDATE access_tokens SET is_revoked = 1 WHERE token_hash = ?1",
+ [token_hash],
+ )?;
+ Ok(())
+ }
+
+ // RSA Key operations
+ pub fn create_rsa_key(&self, key: &DbRsaKey) -> Result<i64> {
+ let mut stmt = self.conn.prepare(
+ "INSERT INTO rsa_keys (kid, private_key_pem, public_key_pem, created_at, is_current)
+ VALUES (?1, ?2, ?3, ?4, ?5)"
+ )?;
+
+ let id = stmt.insert(params![
+ key.kid,
+ key.private_key_pem,
+ key.public_key_pem,
+ key.created_at.to_rfc3339(),
+ key.is_current
+ ])?;
+
+ Ok(id)
+ }
+
+ pub fn get_current_rsa_key(&self) -> Result<Option<DbRsaKey>> {
+ let mut stmt = self.conn.prepare(
+ "SELECT id, kid, private_key_pem, public_key_pem, created_at, is_current
+ FROM rsa_keys WHERE is_current = 1 ORDER BY created_at DESC LIMIT 1"
+ )?;
+
+ let key = stmt.query_row([], |row| {
+ Ok(DbRsaKey {
+ id: row.get(0)?,
+ kid: row.get(1)?,
+ private_key_pem: row.get(2)?,
+ public_key_pem: row.get(3)?,
+ created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?)
+ .map_err(|_| rusqlite::Error::InvalidColumnType(4, "created_at".to_string(), rusqlite::types::Type::Text))?
+ .with_timezone(&Utc),
+ is_current: row.get(5)?,
+ })
+ });
+
+ match key {
+ Ok(key) => Ok(Some(key)),
+ Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
+ Err(e) => Err(e.into()),
+ }
+ }
+
+ pub fn get_all_rsa_keys(&self) -> Result<Vec<DbRsaKey>> {
+ let mut stmt = self.conn.prepare(
+ "SELECT id, kid, private_key_pem, public_key_pem, created_at, is_current
+ FROM rsa_keys ORDER BY created_at DESC"
+ )?;
+
+ let keys = stmt.query_map([], |row| {
+ Ok(DbRsaKey {
+ id: row.get(0)?,
+ kid: row.get(1)?,
+ private_key_pem: row.get(2)?,
+ public_key_pem: row.get(3)?,
+ created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?)
+ .map_err(|_| rusqlite::Error::InvalidColumnType(4, "created_at".to_string(), rusqlite::types::Type::Text))?
+ .with_timezone(&Utc),
+ is_current: row.get(5)?,
+ })
+ })?;
+
+ let mut result = Vec::new();
+ for key in keys {
+ result.push(key?);
+ }
+ Ok(result)
+ }
+
+ pub fn set_current_rsa_key(&self, kid: &str) -> Result<()> {
+ // First, unset all current keys
+ self.conn.execute("UPDATE rsa_keys SET is_current = 0", [])?;
+
+ // Then set the specified key as current
+ self.conn.execute(
+ "UPDATE rsa_keys SET is_current = 1 WHERE kid = ?1",
+ [kid],
+ )?;
+
+ Ok(())
+ }
+
+ // Audit Log operations
+ pub fn create_audit_log(&self, log: &DbAuditLog) -> Result<i64> {
+ let mut stmt = self.conn.prepare(
+ "INSERT INTO audit_logs
+ (event_type, client_id, user_id, ip_address, user_agent, details, created_at, success)
+ VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)"
+ )?;
+
+ let id = stmt.insert(params![
+ log.event_type,
+ log.client_id,
+ log.user_id,
+ log.ip_address,
+ log.user_agent,
+ log.details,
+ log.created_at.to_rfc3339(),
+ log.success
+ ])?;
+
+ Ok(id)
+ }
+
+ // Rate Limiting operations
+ pub fn increment_rate_limit(&self, identifier: &str, endpoint: &str, window_minutes: i32) -> Result<i32> {
+ let now = Utc::now();
+ let window_start = now - chrono::Duration::minutes(window_minutes as i64);
+
+ // Try to increment existing counter in current window
+ let affected = self.conn.execute(
+ "UPDATE rate_limits SET count = count + 1
+ WHERE identifier = ?1 AND endpoint = ?2 AND window_start >= ?3",
+ params![identifier, endpoint, window_start.to_rfc3339()],
+ )?;
+
+ if affected == 0 {
+ // No existing record, create new one
+ self.conn.execute(
+ "INSERT OR REPLACE INTO rate_limits (identifier, endpoint, count, window_start, created_at)
+ VALUES (?1, ?2, 1, ?3, ?4)",
+ params![identifier, endpoint, now.to_rfc3339(), now.to_rfc3339()],
+ )?;
+ Ok(1)
+ } else {
+ // Return current count
+ let count: i32 = self.conn.query_row(
+ "SELECT count FROM rate_limits
+ WHERE identifier = ?1 AND endpoint = ?2 AND window_start >= ?3",
+ params![identifier, endpoint, window_start.to_rfc3339()],
+ |row| row.get(0),
+ )?;
+ Ok(count)
+ }
+ }
+
+ // Cleanup operations
+ pub fn cleanup_expired_codes(&self) -> Result<usize> {
+ let now = Utc::now();
+ let affected = self.conn.execute(
+ "DELETE FROM auth_codes WHERE expires_at < ?1",
+ [now.to_rfc3339()],
+ )?;
+ Ok(affected)
+ }
+
+ pub fn cleanup_expired_tokens(&self) -> Result<usize> {
+ let now = Utc::now();
+ let affected = self.conn.execute(
+ "DELETE FROM access_tokens WHERE expires_at < ?1 AND is_revoked = 1",
+ [now.to_rfc3339()],
+ )?;
+ Ok(affected)
+ }
+
+ pub fn cleanup_old_audit_logs(&self, days: i32) -> Result<usize> {
+ let cutoff = Utc::now() - chrono::Duration::days(days as i64);
+ let affected = self.conn.execute(
+ "DELETE FROM audit_logs WHERE created_at < ?1",
+ [cutoff.to_rfc3339()],
+ )?;
+ Ok(affected)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_database_creation() {
+ let _db = Database::new_in_memory().expect("Failed to create database");
+ assert!(true); // If we got here, database was created successfully
+ }
+
+ #[test]
+ fn test_oauth_client_operations() {
+ let db = Database::new_in_memory().expect("Failed to create database");
+
+ let client = DbOAuthClient {
+ id: 0,
+ client_id: "test_client".to_string(),
+ client_secret_hash: "hash123".to_string(),
+ client_name: "Test Client".to_string(),
+ redirect_uris: "[\"http://localhost:3000/callback\"]".to_string(),
+ scopes: "openid profile".to_string(),
+ grant_types: "authorization_code".to_string(),
+ response_types: "code".to_string(),
+ created_at: Utc::now(),
+ updated_at: Utc::now(),
+ is_active: true,
+ };
+
+ let id = db.create_oauth_client(&client).expect("Failed to create client");
+ assert!(id > 0);
+
+ let retrieved = db.get_oauth_client("test_client").expect("Failed to get client");
+ assert!(retrieved.is_some());
+ assert_eq!(retrieved.unwrap().client_name, "Test Client");
+ }
+
+ #[test]
+ fn test_auth_code_operations() {
+ let db = Database::new_in_memory().expect("Failed to create database");
+
+ // First create a client (required for foreign key constraint)
+ let client = DbOAuthClient {
+ id: 0,
+ client_id: "test_client".to_string(),
+ client_secret_hash: "hash123".to_string(),
+ client_name: "Test Client".to_string(),
+ redirect_uris: "[\"http://localhost:3000/callback\"]".to_string(),
+ scopes: "openid profile".to_string(),
+ grant_types: "authorization_code".to_string(),
+ response_types: "code".to_string(),
+ created_at: Utc::now(),
+ updated_at: Utc::now(),
+ is_active: true,
+ };
+ db.create_oauth_client(&client).expect("Failed to create client");
+
+ let auth_code = DbAuthCode {
+ id: 0,
+ code: "test_code_123".to_string(),
+ client_id: "test_client".to_string(),
+ user_id: "test_user".to_string(),
+ redirect_uri: "http://localhost:3000/callback".to_string(),
+ scope: Some("openid".to_string()),
+ expires_at: Utc::now() + chrono::Duration::minutes(10),
+ created_at: Utc::now(),
+ is_used: false,
+ code_challenge: Some("challenge123".to_string()),
+ code_challenge_method: Some("S256".to_string()),
+ };
+
+ let id = db.create_auth_code(&auth_code).expect("Failed to create auth code");
+ assert!(id > 0);
+
+ let retrieved = db.get_auth_code("test_code_123").expect("Failed to get auth code");
+ assert!(retrieved.is_some());
+ let code = retrieved.unwrap();
+ assert_eq!(code.client_id, "test_client");
+ assert_eq!(code.is_used, false);
+
+ db.mark_auth_code_used("test_code_123").expect("Failed to mark code as used");
+ let updated = db.get_auth_code("test_code_123").expect("Failed to get auth code");
+ assert_eq!(updated.unwrap().is_used, true);
+ }
+} \ No newline at end of file