summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-11 15:15:41 -0600
committermo khan <mo@mokhan.ca>2025-06-11 15:15:41 -0600
commitaea6bd6ec7d7e70a67723edf6327df4a9cc65d89 (patch)
tree80fcb6cbda7baa5ed15cf044d7583acb2438c4d2
parent4435ee26b79648e92d0f172e42f9e6629e955505 (diff)
chore: run rustfmt again
-rw-r--r--src/clients.rs112
-rw-r--r--src/config.rs6
-rw-r--r--src/database.rs146
-rw-r--r--src/keys.rs28
-rw-r--r--src/main.rs49
-rw-r--r--src/migrations.rs46
6 files changed, 247 insertions, 140 deletions
diff --git a/src/clients.rs b/src/clients.rs
index bc9aea5..7941d8c 100644
--- a/src/clients.rs
+++ b/src/clients.rs
@@ -1,12 +1,12 @@
+use crate::database::{Database, DbOAuthClient};
+use anyhow::Result;
use base64::Engine;
+use chrono::Utc;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use uuid::Uuid;
-use chrono::Utc;
-use crate::database::{Database, DbOAuthClient};
-use anyhow::Result;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthClient {
@@ -37,10 +37,10 @@ impl ClientManager {
clients: HashMap::new(),
database: database.clone(),
};
-
+
// Load existing clients from database into cache
manager.load_clients_from_db()?;
-
+
// Register a default test client for development if it doesn't exist
if manager.get_client_from_db("test_client")?.is_none() {
let _ = manager.register_client(
@@ -48,13 +48,17 @@ impl ClientManager {
"test_secret".to_string(),
vec!["http://localhost:3000/callback".to_string()],
"Test Client".to_string(),
- vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
+ vec![
+ "openid".to_string(),
+ "profile".to_string(),
+ "email".to_string(),
+ ],
); // Ignore errors if client already exists
}
-
+
Ok(manager)
}
-
+
fn load_clients_from_db(&mut self) -> Result<()> {
// This is a simplified version - in practice you'd want to load all clients
// For now we'll load on-demand
@@ -129,40 +133,52 @@ impl ClientManager {
if let Some(client) = self.clients.get(client_id) {
return Some(client);
}
-
+
// If not in cache, try to load from database
// For thread safety, we can't mutate self here, so we'll return None
// In a real implementation, you'd want a more sophisticated caching strategy
None
}
-
+
pub fn get_client_from_db(&mut self, client_id: &str) -> Result<Option<OAuthClient>> {
// Check cache first
if let Some(client) = self.clients.get(client_id) {
return Ok(Some(client.clone()));
}
-
+
// Load from database
let db_client = {
let db = self.database.lock().unwrap();
db.get_oauth_client(client_id)?
};
-
+
if let Some(db_client) = db_client {
let redirect_uris: Vec<String> = serde_json::from_str(&db_client.redirect_uris)?;
- let scopes: Vec<String> = db_client.scopes.split_whitespace().map(|s| s.to_string()).collect();
-
+ let scopes: Vec<String> = db_client
+ .scopes
+ .split_whitespace()
+ .map(|s| s.to_string())
+ .collect();
+
let client = OAuthClient {
client_id: db_client.client_id.clone(),
client_secret_hash: db_client.client_secret_hash,
redirect_uris,
client_name: db_client.client_name,
scopes,
- grant_types: db_client.grant_types.split_whitespace().map(|s| s.to_string()).collect(),
- response_types: db_client.response_types.split_whitespace().map(|s| s.to_string()).collect(),
+ grant_types: db_client
+ .grant_types
+ .split_whitespace()
+ .map(|s| s.to_string())
+ .collect(),
+ response_types: db_client
+ .response_types
+ .split_whitespace()
+ .map(|s| s.to_string())
+ .collect(),
created_at: db_client.created_at.timestamp() as u64,
};
-
+
// Cache it
self.clients.insert(db_client.client_id, client.clone());
Ok(Some(client))
@@ -181,7 +197,7 @@ impl ClientManager {
return false;
}
};
-
+
let provided_hash = Self::hash_secret(client_secret);
// Use constant-time comparison to prevent timing attacks
self.constant_time_eq(&client.client_secret_hash, &provided_hash)
@@ -199,7 +215,9 @@ impl ClientManager {
if let Ok(Some(client)) = self.get_client_from_db(client_id) {
if let Some(scopes_str) = requested_scopes {
let requested: Vec<&str> = scopes_str.split_whitespace().collect();
- requested.iter().all(|scope| client.scopes.contains(&scope.to_string()))
+ requested
+ .iter()
+ .all(|scope| client.scopes.contains(&scope.to_string()))
} else {
true // No scopes requested is valid
}
@@ -224,7 +242,7 @@ impl ClientManager {
if a.len() != b.len() {
return false;
}
-
+
let mut result = 0u8;
for (byte_a, byte_b) in a.bytes().zip(b.bytes()) {
result |= byte_a ^ byte_b;
@@ -232,16 +250,24 @@ impl ClientManager {
result == 0
}
- pub fn generate_client_credentials(&mut self, client_name: String, redirect_uris: Vec<String>) -> Result<ClientCredentials> {
+ pub fn generate_client_credentials(
+ &mut self,
+ client_name: String,
+ redirect_uris: Vec<String>,
+ ) -> Result<ClientCredentials> {
let client_id = format!("client_{}", Uuid::new_v4().to_string().replace("-", ""));
let client_secret = Uuid::new_v4().to_string();
-
+
self.register_client(
client_id.clone(),
client_secret.clone(),
redirect_uris,
client_name,
- vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
+ vec![
+ "openid".to_string(),
+ "profile".to_string(),
+ "email".to_string(),
+ ],
)?;
Ok(ClientCredentials {
@@ -253,7 +279,7 @@ impl ClientManager {
pub fn list_clients(&self) -> Vec<&OAuthClient> {
self.clients.values().collect()
}
-
+
pub fn list_all_clients_from_db(&self) -> Result<Vec<DbOAuthClient>> {
// This would require a new database method - for now return empty
Ok(vec![])
@@ -270,13 +296,13 @@ pub fn parse_basic_auth(auth_header: &str) -> Option<(String, String)> {
let decoded = base64::engine::general_purpose::STANDARD
.decode(encoded)
.ok()?;
-
+
let credentials = String::from_utf8(decoded).ok()?;
let mut parts = credentials.splitn(2, ':');
-
+
let username = parts.next()?.to_string();
let password = parts.next()?.to_string();
-
+
Some((username, password))
}
@@ -288,7 +314,7 @@ mod disabled_tests {
#[test]
fn test_client_registration() {
let mut manager = ClientManager::new();
-
+
let result = manager.register_client(
"new_client".to_string(),
"secret123".to_string(),
@@ -296,7 +322,7 @@ mod disabled_tests {
"Test App".to_string(),
vec!["openid".to_string()],
);
-
+
assert!(result.is_ok());
let client = result.unwrap();
assert_eq!(client.client_id, "new_client");
@@ -306,7 +332,7 @@ mod disabled_tests {
#[test]
fn test_duplicate_client_id() {
let mut manager = ClientManager::new();
-
+
manager.register_client(
"duplicate".to_string(),
"secret1".to_string(),
@@ -314,7 +340,7 @@ mod disabled_tests {
"App 1".to_string(),
vec!["openid".to_string()],
).unwrap();
-
+
let result = manager.register_client(
"duplicate".to_string(),
"secret2".to_string(),
@@ -322,14 +348,14 @@ mod disabled_tests {
"App 2".to_string(),
vec!["openid".to_string()],
);
-
+
assert!(result.is_err());
}
#[test]
fn test_client_authentication() {
let mut manager = ClientManager::new();
-
+
manager.register_client(
"auth_client".to_string(),
"correct_secret".to_string(),
@@ -337,7 +363,7 @@ mod disabled_tests {
"Auth Test".to_string(),
vec!["openid".to_string()],
).unwrap();
-
+
assert!(manager.authenticate_client("auth_client", "correct_secret"));
assert!(!manager.authenticate_client("auth_client", "wrong_secret"));
assert!(!manager.authenticate_client("nonexistent", "any_secret"));
@@ -346,7 +372,7 @@ mod disabled_tests {
#[test]
fn test_redirect_uri_validation() {
let mut manager = ClientManager::new();
-
+
manager.register_client(
"uri_client".to_string(),
"secret".to_string(),
@@ -354,7 +380,7 @@ mod disabled_tests {
"URI Test".to_string(),
vec!["openid".to_string()],
).unwrap();
-
+
assert!(manager.is_redirect_uri_valid("uri_client", "https://app.com/callback"));
assert!(manager.is_redirect_uri_valid("uri_client", "http://localhost:3000/callback"));
assert!(!manager.is_redirect_uri_valid("uri_client", "https://evil.com/callback"));
@@ -364,7 +390,7 @@ mod disabled_tests {
#[test]
fn test_scope_validation() {
let mut manager = ClientManager::new();
-
+
manager.register_client(
"scope_client".to_string(),
"secret".to_string(),
@@ -372,7 +398,7 @@ mod disabled_tests {
"Scope Test".to_string(),
vec!["openid".to_string(), "profile".to_string()],
).unwrap();
-
+
assert!(manager.is_scope_valid("scope_client", &Some("openid".to_string())));
assert!(manager.is_scope_valid("scope_client", &Some("openid profile".to_string())));
assert!(!manager.is_scope_valid("scope_client", &Some("openid profile email".to_string())));
@@ -383,7 +409,7 @@ mod disabled_tests {
fn test_basic_auth_parsing() {
let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; // test_client:test_secret
let result = parse_basic_auth(auth_header);
-
+
assert!(result.is_some());
let (username, password) = result.unwrap();
assert_eq!(username, "test_client");
@@ -393,19 +419,19 @@ mod disabled_tests {
#[test]
fn test_generate_client_credentials() {
let mut manager = ClientManager::new();
-
+
let result = manager.generate_client_credentials(
"Generated App".to_string(),
vec!["https://generated.com/callback".to_string()],
);
-
+
assert!(result.is_ok());
let credentials = result.unwrap();
assert!(credentials.client_id.starts_with("client_"));
assert!(!credentials.client_secret.is_empty());
-
+
// Verify the client was actually registered
assert!(manager.authenticate_client(&credentials.client_id, &credentials.client_secret));
}
}
-*/ \ No newline at end of file
+*/
diff --git a/src/config.rs b/src/config.rs
index 266f669..e496581 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -13,8 +13,10 @@ pub struct Config {
impl Config {
pub fn from_env() -> Self {
let bind_addr = std::env::var("BIND_ADDR").unwrap_or_else(|_| "127.0.0.1:7878".to_string());
- let issuer_url = std::env::var("ISSUER_URL").unwrap_or_else(|_| format!("http://{}", bind_addr));
- let database_path = std::env::var("DATABASE_PATH").unwrap_or_else(|_| "oauth.db".to_string());
+ let issuer_url =
+ std::env::var("ISSUER_URL").unwrap_or_else(|_| format!("http://{}", bind_addr));
+ let database_path =
+ std::env::var("DATABASE_PATH").unwrap_or_else(|_| "oauth.db".to_string());
let rate_limit_requests_per_minute = std::env::var("RATE_LIMIT_RPM")
.unwrap_or_else(|_| "60".to_string())
.parse()
diff --git a/src/database.rs b/src/database.rs
index dc33cf8..2472d1a 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -1,6 +1,6 @@
use anyhow::Result;
use chrono::{DateTime, Utc};
-use rusqlite::{params, Connection};
+use rusqlite::{Connection, params};
use serde::{Deserialize, Serialize};
use std::path::Path;
@@ -10,9 +10,9 @@ pub struct DbOAuthClient {
pub client_id: String,
pub client_secret_hash: String,
pub client_name: String,
- pub redirect_uris: String, // JSON array
- pub scopes: String, // Space-separated
- pub grant_types: String, // Space-separated
+ pub redirect_uris: String, // JSON array
+ pub scopes: String, // Space-separated
+ pub grant_types: String, // Space-separated
pub response_types: String, // Space-separated
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
@@ -270,7 +270,7 @@ impl Database {
"INSERT INTO oauth_clients
(client_id, client_secret_hash, client_name, redirect_uris, scopes,
grant_types, response_types, created_at, updated_at, is_active)
- VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)"
+ VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
)?;
let id = stmt.insert(params![
@@ -293,7 +293,7 @@ impl Database {
let mut stmt = self.conn.prepare(
"SELECT id, client_id, client_secret_hash, client_name, redirect_uris,
scopes, grant_types, response_types, created_at, updated_at, is_active
- FROM oauth_clients WHERE client_id = ?1 AND is_active = 1"
+ FROM oauth_clients WHERE client_id = ?1 AND is_active = 1",
)?;
let client = stmt.query_row([client_id], |row| {
@@ -307,10 +307,22 @@ impl Database {
grant_types: row.get(6)?,
response_types: row.get(7)?,
created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?)
- .map_err(|_| rusqlite::Error::InvalidColumnType(8, "created_at".to_string(), rusqlite::types::Type::Text))?
+ .map_err(|_| {
+ rusqlite::Error::InvalidColumnType(
+ 8,
+ "created_at".to_string(),
+ rusqlite::types::Type::Text,
+ )
+ })?
.with_timezone(&Utc),
updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?)
- .map_err(|_| rusqlite::Error::InvalidColumnType(9, "updated_at".to_string(), rusqlite::types::Type::Text))?
+ .map_err(|_| {
+ rusqlite::Error::InvalidColumnType(
+ 9,
+ "updated_at".to_string(),
+ rusqlite::types::Type::Text,
+ )
+ })?
.with_timezone(&Utc),
is_active: row.get(10)?,
})
@@ -329,7 +341,7 @@ impl Database {
"INSERT INTO auth_codes
(code, client_id, user_id, redirect_uri, scope, expires_at, created_at,
is_used, code_challenge, code_challenge_method)
- VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)"
+ VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
)?;
let id = stmt.insert(params![
@@ -352,7 +364,7 @@ impl Database {
let mut stmt = self.conn.prepare(
"SELECT id, code, client_id, user_id, redirect_uri, scope, expires_at,
created_at, is_used, code_challenge, code_challenge_method
- FROM auth_codes WHERE code = ?1"
+ FROM auth_codes WHERE code = ?1",
)?;
let auth_code = stmt.query_row([code], |row| {
@@ -364,10 +376,22 @@ impl Database {
redirect_uri: row.get(4)?,
scope: row.get(5)?,
expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?)
- .map_err(|_| rusqlite::Error::InvalidColumnType(6, "expires_at".to_string(), rusqlite::types::Type::Text))?
+ .map_err(|_| {
+ rusqlite::Error::InvalidColumnType(
+ 6,
+ "expires_at".to_string(),
+ rusqlite::types::Type::Text,
+ )
+ })?
.with_timezone(&Utc),
created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?)
- .map_err(|_| rusqlite::Error::InvalidColumnType(7, "created_at".to_string(), rusqlite::types::Type::Text))?
+ .map_err(|_| {
+ rusqlite::Error::InvalidColumnType(
+ 7,
+ "created_at".to_string(),
+ rusqlite::types::Type::Text,
+ )
+ })?
.with_timezone(&Utc),
is_used: row.get(8)?,
code_challenge: row.get(9)?,
@@ -383,10 +407,8 @@ impl Database {
}
pub fn mark_auth_code_used(&self, code: &str) -> Result<()> {
- self.conn.execute(
- "UPDATE auth_codes SET is_used = 1 WHERE code = ?1",
- [code],
- )?;
+ self.conn
+ .execute("UPDATE auth_codes SET is_used = 1 WHERE code = ?1", [code])?;
Ok(())
}
@@ -395,7 +417,7 @@ impl Database {
let mut stmt = self.conn.prepare(
"INSERT INTO access_tokens
(token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash)
- VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)"
+ VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
)?;
let id = stmt.insert(params![
@@ -426,10 +448,22 @@ impl Database {
user_id: row.get(3)?,
scope: row.get(4)?,
expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(5)?)
- .map_err(|_| rusqlite::Error::InvalidColumnType(5, "expires_at".to_string(), rusqlite::types::Type::Text))?
+ .map_err(|_| {
+ rusqlite::Error::InvalidColumnType(
+ 5,
+ "expires_at".to_string(),
+ rusqlite::types::Type::Text,
+ )
+ })?
.with_timezone(&Utc),
created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?)
- .map_err(|_| rusqlite::Error::InvalidColumnType(6, "created_at".to_string(), rusqlite::types::Type::Text))?
+ .map_err(|_| {
+ rusqlite::Error::InvalidColumnType(
+ 6,
+ "created_at".to_string(),
+ rusqlite::types::Type::Text,
+ )
+ })?
.with_timezone(&Utc),
is_revoked: row.get(7)?,
token_hash: row.get(8)?,
@@ -455,7 +489,7 @@ impl Database {
pub fn create_rsa_key(&self, key: &DbRsaKey) -> Result<i64> {
let mut stmt = self.conn.prepare(
"INSERT INTO rsa_keys (kid, private_key_pem, public_key_pem, created_at, is_current)
- VALUES (?1, ?2, ?3, ?4, ?5)"
+ VALUES (?1, ?2, ?3, ?4, ?5)",
)?;
let id = stmt.insert(params![
@@ -472,7 +506,7 @@ impl Database {
pub fn get_current_rsa_key(&self) -> Result<Option<DbRsaKey>> {
let mut stmt = self.conn.prepare(
"SELECT id, kid, private_key_pem, public_key_pem, created_at, is_current
- FROM rsa_keys WHERE is_current = 1 ORDER BY created_at DESC LIMIT 1"
+ FROM rsa_keys WHERE is_current = 1 ORDER BY created_at DESC LIMIT 1",
)?;
let key = stmt.query_row([], |row| {
@@ -482,7 +516,13 @@ impl Database {
private_key_pem: row.get(2)?,
public_key_pem: row.get(3)?,
created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?)
- .map_err(|_| rusqlite::Error::InvalidColumnType(4, "created_at".to_string(), rusqlite::types::Type::Text))?
+ .map_err(|_| {
+ rusqlite::Error::InvalidColumnType(
+ 4,
+ "created_at".to_string(),
+ rusqlite::types::Type::Text,
+ )
+ })?
.with_timezone(&Utc),
is_current: row.get(5)?,
})
@@ -498,7 +538,7 @@ impl Database {
pub fn get_all_rsa_keys(&self) -> Result<Vec<DbRsaKey>> {
let mut stmt = self.conn.prepare(
"SELECT id, kid, private_key_pem, public_key_pem, created_at, is_current
- FROM rsa_keys ORDER BY created_at DESC"
+ FROM rsa_keys ORDER BY created_at DESC",
)?;
let keys = stmt.query_map([], |row| {
@@ -508,7 +548,13 @@ impl Database {
private_key_pem: row.get(2)?,
public_key_pem: row.get(3)?,
created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(4)?)
- .map_err(|_| rusqlite::Error::InvalidColumnType(4, "created_at".to_string(), rusqlite::types::Type::Text))?
+ .map_err(|_| {
+ rusqlite::Error::InvalidColumnType(
+ 4,
+ "created_at".to_string(),
+ rusqlite::types::Type::Text,
+ )
+ })?
.with_timezone(&Utc),
is_current: row.get(5)?,
})
@@ -523,14 +569,13 @@ impl Database {
pub fn set_current_rsa_key(&self, kid: &str) -> Result<()> {
// First, unset all current keys
- self.conn.execute("UPDATE rsa_keys SET is_current = 0", [])?;
-
+ self.conn
+ .execute("UPDATE rsa_keys SET is_current = 0", [])?;
+
// Then set the specified key as current
- self.conn.execute(
- "UPDATE rsa_keys SET is_current = 1 WHERE kid = ?1",
- [kid],
- )?;
-
+ self.conn
+ .execute("UPDATE rsa_keys SET is_current = 1 WHERE kid = ?1", [kid])?;
+
Ok(())
}
@@ -539,7 +584,7 @@ impl Database {
let mut stmt = self.conn.prepare(
"INSERT INTO audit_logs
(event_type, client_id, user_id, ip_address, user_agent, details, created_at, success)
- VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)"
+ VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
)?;
let id = stmt.insert(params![
@@ -557,7 +602,12 @@ impl Database {
}
// Rate Limiting operations
- pub fn increment_rate_limit(&self, identifier: &str, endpoint: &str, window_minutes: i32) -> Result<i32> {
+ pub fn increment_rate_limit(
+ &self,
+ identifier: &str,
+ endpoint: &str,
+ window_minutes: i32,
+ ) -> Result<i32> {
let now = Utc::now();
let window_start = now - chrono::Duration::minutes(window_minutes as i64);
@@ -630,7 +680,7 @@ mod tests {
#[test]
fn test_oauth_client_operations() {
let db = Database::new_in_memory().expect("Failed to create database");
-
+
let client = DbOAuthClient {
id: 0,
client_id: "test_client".to_string(),
@@ -645,10 +695,14 @@ mod tests {
is_active: true,
};
- let id = db.create_oauth_client(&client).expect("Failed to create client");
+ let id = db
+ .create_oauth_client(&client)
+ .expect("Failed to create client");
assert!(id > 0);
- let retrieved = db.get_oauth_client("test_client").expect("Failed to get client");
+ let retrieved = db
+ .get_oauth_client("test_client")
+ .expect("Failed to get client");
assert!(retrieved.is_some());
assert_eq!(retrieved.unwrap().client_name, "Test Client");
}
@@ -671,7 +725,8 @@ mod tests {
updated_at: Utc::now(),
is_active: true,
};
- db.create_oauth_client(&client).expect("Failed to create client");
+ db.create_oauth_client(&client)
+ .expect("Failed to create client");
let auth_code = DbAuthCode {
id: 0,
@@ -687,17 +742,24 @@ mod tests {
code_challenge_method: Some("S256".to_string()),
};
- let id = db.create_auth_code(&auth_code).expect("Failed to create auth code");
+ let id = db
+ .create_auth_code(&auth_code)
+ .expect("Failed to create auth code");
assert!(id > 0);
- let retrieved = db.get_auth_code("test_code_123").expect("Failed to get auth code");
+ let retrieved = db
+ .get_auth_code("test_code_123")
+ .expect("Failed to get auth code");
assert!(retrieved.is_some());
let code = retrieved.unwrap();
assert_eq!(code.client_id, "test_client");
assert_eq!(code.is_used, false);
- db.mark_auth_code_used("test_code_123").expect("Failed to mark code as used");
- let updated = db.get_auth_code("test_code_123").expect("Failed to get auth code");
+ db.mark_auth_code_used("test_code_123")
+ .expect("Failed to mark code as used");
+ let updated = db
+ .get_auth_code("test_code_123")
+ .expect("Failed to get auth code");
assert_eq!(updated.unwrap().is_used, true);
}
-} \ No newline at end of file
+}
diff --git a/src/keys.rs b/src/keys.rs
index 16b943c..675eb61 100644
--- a/src/keys.rs
+++ b/src/keys.rs
@@ -1,6 +1,9 @@
+use crate::database::{Database, DbRsaKey};
+use anyhow::Result;
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
+use chrono::Utc;
use jsonwebtoken::{DecodingKey, EncodingKey};
-use rsa::pkcs8::{EncodePrivateKey, EncodePublicKey, DecodePrivateKey, DecodePublicKey};
+use rsa::pkcs8::{DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey};
use rsa::traits::PublicKeyParts;
use rsa::{RsaPrivateKey, RsaPublicKey};
use serde::Serialize;
@@ -8,9 +11,6 @@ use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
use uuid::Uuid;
-use chrono::Utc;
-use crate::database::{Database, DbRsaKey};
-use anyhow::Result;
#[derive(Clone)]
pub struct KeyPair {
@@ -56,28 +56,28 @@ impl KeyManager {
// Load existing keys from database
manager.load_keys_from_db()?;
-
+
// If no keys exist, generate the first one
if manager.keys.is_empty() {
manager.generate_new_key()?;
}
-
+
Ok(manager)
}
-
+
fn load_keys_from_db(&mut self) -> Result<()> {
let db_keys = {
let db = self.database.lock().unwrap();
db.get_all_rsa_keys()?
};
-
+
for db_key in db_keys {
let private_key = RsaPrivateKey::from_pkcs8_pem(&db_key.private_key_pem)?;
let public_key = RsaPublicKey::from_public_key_pem(&db_key.public_key_pem)?;
-
+
let encoding_key = EncodingKey::from_rsa_pem(db_key.private_key_pem.as_bytes())?;
let decoding_key = DecodingKey::from_rsa_pem(db_key.public_key_pem.as_bytes())?;
-
+
let key_pair = KeyPair {
kid: db_key.kid.clone(),
private_key,
@@ -86,14 +86,14 @@ impl KeyManager {
encoding_key,
decoding_key,
};
-
+
self.keys.insert(db_key.kid.clone(), key_pair);
-
+
if db_key.is_current {
self.current_key_id = Some(db_key.kid);
}
}
-
+
Ok(())
}
@@ -121,7 +121,7 @@ impl KeyManager {
created_at: now,
is_current: true, // This will be the new current key
};
-
+
{
let db = self.database.lock().unwrap();
db.create_rsa_key(&db_key)?;
diff --git a/src/main.rs b/src/main.rs
index f5951e0..0612bfa 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,31 +1,36 @@
-use sts::http::Server;
-use sts::Config;
use std::thread;
use std::time::Duration;
+use sts::Config;
+use sts::http::Server;
fn main() {
let config = Config::from_env();
let server = Server::new(config.clone()).expect("Failed to create server");
-
+
// Start cleanup task in background
let cleanup_config = config.clone();
thread::spawn(move || {
loop {
- thread::sleep(Duration::from_secs(cleanup_config.cleanup_interval_hours as u64 * 3600));
+ thread::sleep(Duration::from_secs(
+ cleanup_config.cleanup_interval_hours as u64 * 3600,
+ ));
// Note: In the current implementation, we don't have direct access to the OAuth server
// from here to call cleanup_expired_data(). In a production implementation,
// you'd want to structure this differently or use a background job queue.
}
});
-
+
println!("Starting OAuth2 STS server...");
println!("Configuration:");
println!(" Bind Address: {}", config.bind_addr);
println!(" Issuer URL: {}", config.issuer_url);
println!(" Database: {}", config.database_path);
- println!(" Rate Limit: {} requests/minute", config.rate_limit_requests_per_minute);
+ println!(
+ " Rate Limit: {} requests/minute",
+ config.rate_limit_requests_per_minute
+ );
println!(" Audit Logging: {}", config.enable_audit_logging);
-
+
server.start();
}
@@ -102,15 +107,15 @@ mod disabled_tests {
fn test_jwks_endpoint() {
let config = sts::Config::from_env();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
-
+
let jwks_json = oauth_server.get_jwks();
assert!(!jwks_json.is_empty());
-
+
// Parse the JSON to verify structure
let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON");
assert!(jwks["keys"].is_array());
assert!(jwks["keys"].as_array().unwrap().len() > 0);
-
+
let key = &jwks["keys"][0];
assert_eq!(key["kty"], "RSA");
assert_eq!(key["use"], "sig");
@@ -133,7 +138,7 @@ mod disabled_tests {
auth_params.insert("state".to_string(), "test_state".to_string());
let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
-
+
// Extract the authorization code from redirect URL
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
@@ -150,17 +155,17 @@ mod disabled_tests {
token_params.insert("client_secret".to_string(), "test_secret".to_string());
let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed");
-
+
// Parse token response
let token_response: serde_json::Value = serde_json::from_str(&token_result)
.expect("Invalid token response JSON");
-
+
assert_eq!(token_response["token_type"], "Bearer");
assert_eq!(token_response["expires_in"], 3600);
assert!(token_response["access_token"].is_string());
-
+
let access_token = token_response["access_token"].as_str().unwrap();
-
+
// Step 3: Verify the JWT token has RSA signature and key ID
let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
@@ -168,7 +173,7 @@ mod disabled_tests {
assert!(!header.kid.as_ref().unwrap().is_empty());
}
- #[test]
+ #[test]
fn test_token_validation_with_jwks() {
let config = sts::Config::from_env();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
@@ -201,11 +206,11 @@ mod disabled_tests {
// Get the JWKS
let jwks_json = oauth_server.get_jwks();
let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON");
-
+
// Decode the token header to get the key ID
let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
let kid = header.kid.as_ref().expect("No key ID in token");
-
+
// Find the matching key in JWKS
let matching_key = jwks["keys"]
.as_array()
@@ -213,7 +218,7 @@ mod disabled_tests {
.iter()
.find(|key| key["kid"] == *kid)
.expect("Key ID not found in JWKS");
-
+
assert_eq!(matching_key["kty"], "RSA");
assert_eq!(matching_key["alg"], "RS256");
}
@@ -255,17 +260,17 @@ mod disabled_tests {
&jsonwebtoken::DecodingKey::from_secret(b"dummy"), // We're not validating, just parsing
&jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256)
);
-
+
// Since we can't validate with a dummy key, we'll just verify the structure
// by decoding the payload manually
let parts: Vec<&str> = access_token.split('.').collect();
assert_eq!(parts.len(), 3); // header.payload.signature
-
+
let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(parts[1])
.expect("Failed to decode payload");
let claims: serde_json::Value = serde_json::from_slice(&payload).expect("Invalid claims JSON");
-
+
assert!(claims["sub"].is_string());
assert!(claims["iss"].is_string());
assert!(claims["aud"].is_string());
diff --git a/src/migrations.rs b/src/migrations.rs
index c7cd6bf..5076a9e 100644
--- a/src/migrations.rs
+++ b/src/migrations.rs
@@ -43,13 +43,16 @@ impl<'a> MigrationRunner<'a> {
// Get current migration version
let current_version = self.get_current_version()?;
-
+
println!("Current database version: {}", current_version);
// Run pending migrations
for migration in MIGRATIONS {
if migration.version > current_version {
- println!("Running migration {}: {}", migration.version, migration.name);
+ println!(
+ "Running migration {}: {}",
+ migration.version, migration.name
+ );
self.run_migration(migration)?;
}
}
@@ -70,7 +73,7 @@ impl<'a> MigrationRunner<'a> {
fn run_migration(&self, migration: &Migration) -> Result<()> {
// Execute the migration SQL
self.conn.execute_batch(migration.sql)?;
-
+
// Record the migration as applied
self.conn.execute(
"INSERT INTO schema_migrations (version, name, applied_at) VALUES (?1, ?2, ?3)",
@@ -86,14 +89,14 @@ impl<'a> MigrationRunner<'a> {
pub fn rollback_to_version(&self, target_version: i32) -> Result<()> {
println!("Rolling back to version {}", target_version);
-
+
// This is a simplified rollback - in practice you'd need down migrations
// For now, just remove migration records
self.conn.execute(
"DELETE FROM schema_migrations WHERE version > ?1",
[target_version],
)?;
-
+
println!("Rollback completed (Note: This doesn't actually undo schema changes)");
Ok(())
}
@@ -101,11 +104,11 @@ impl<'a> MigrationRunner<'a> {
pub fn show_migration_status(&self) -> Result<()> {
println!("Migration Status:");
println!("================");
-
- let mut stmt = self.conn.prepare(
- "SELECT version, name, applied_at FROM schema_migrations ORDER BY version"
- )?;
-
+
+ let mut stmt = self
+ .conn
+ .prepare("SELECT version, name, applied_at FROM schema_migrations ORDER BY version")?;
+
let migrations = stmt.query_map([], |row| {
Ok((
row.get::<_, i32>(0)?,
@@ -116,14 +119,20 @@ impl<'a> MigrationRunner<'a> {
for migration in migrations {
let (version, name, applied_at) = migration?;
- println!("✅ Migration {}: {} (applied: {})", version, name, applied_at);
+ println!(
+ "✅ Migration {}: {} (applied: {})",
+ version, name, applied_at
+ );
}
// Show pending migrations
let current_version = self.get_current_version()?;
for migration in MIGRATIONS {
if migration.version > current_version {
- println!("⏳ Migration {}: {} (pending)", migration.version, migration.name);
+ println!(
+ "⏳ Migration {}: {} (pending)",
+ migration.version, migration.name
+ );
}
}
@@ -139,14 +148,17 @@ mod tests {
fn test_migration_runner() {
let conn = Connection::open_in_memory().unwrap();
let runner = MigrationRunner::new(&conn);
-
+
// Should start with version 0
assert_eq!(runner.get_current_version().unwrap(), 0);
-
+
// Run migrations
runner.run_migrations().unwrap();
-
+
// Should now be at latest version
- assert_eq!(runner.get_current_version().unwrap(), MIGRATIONS.len() as i32);
+ assert_eq!(
+ runner.get_current_version().unwrap(),
+ MIGRATIONS.len() as i32
+ );
}
-} \ No newline at end of file
+}