summaryrefslogtreecommitdiff
path: root/src/clients.rs
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-11 15:15:41 -0600
committermo khan <mo@mokhan.ca>2025-06-11 15:15:41 -0600
commitaea6bd6ec7d7e70a67723edf6327df4a9cc65d89 (patch)
tree80fcb6cbda7baa5ed15cf044d7583acb2438c4d2 /src/clients.rs
parent4435ee26b79648e92d0f172e42f9e6629e955505 (diff)
chore: run rustfmt again
Diffstat (limited to 'src/clients.rs')
-rw-r--r--src/clients.rs112
1 files changed, 69 insertions, 43 deletions
diff --git a/src/clients.rs b/src/clients.rs
index bc9aea5..7941d8c 100644
--- a/src/clients.rs
+++ b/src/clients.rs
@@ -1,12 +1,12 @@
+use crate::database::{Database, DbOAuthClient};
+use anyhow::Result;
use base64::Engine;
+use chrono::Utc;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use uuid::Uuid;
-use chrono::Utc;
-use crate::database::{Database, DbOAuthClient};
-use anyhow::Result;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthClient {
@@ -37,10 +37,10 @@ impl ClientManager {
clients: HashMap::new(),
database: database.clone(),
};
-
+
// Load existing clients from database into cache
manager.load_clients_from_db()?;
-
+
// Register a default test client for development if it doesn't exist
if manager.get_client_from_db("test_client")?.is_none() {
let _ = manager.register_client(
@@ -48,13 +48,17 @@ impl ClientManager {
"test_secret".to_string(),
vec!["http://localhost:3000/callback".to_string()],
"Test Client".to_string(),
- vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
+ vec![
+ "openid".to_string(),
+ "profile".to_string(),
+ "email".to_string(),
+ ],
); // Ignore errors if client already exists
}
-
+
Ok(manager)
}
-
+
fn load_clients_from_db(&mut self) -> Result<()> {
// This is a simplified version - in practice you'd want to load all clients
// For now we'll load on-demand
@@ -129,40 +133,52 @@ impl ClientManager {
if let Some(client) = self.clients.get(client_id) {
return Some(client);
}
-
+
// If not in cache, try to load from database
// For thread safety, we can't mutate self here, so we'll return None
// In a real implementation, you'd want a more sophisticated caching strategy
None
}
-
+
pub fn get_client_from_db(&mut self, client_id: &str) -> Result<Option<OAuthClient>> {
// Check cache first
if let Some(client) = self.clients.get(client_id) {
return Ok(Some(client.clone()));
}
-
+
// Load from database
let db_client = {
let db = self.database.lock().unwrap();
db.get_oauth_client(client_id)?
};
-
+
if let Some(db_client) = db_client {
let redirect_uris: Vec<String> = serde_json::from_str(&db_client.redirect_uris)?;
- let scopes: Vec<String> = db_client.scopes.split_whitespace().map(|s| s.to_string()).collect();
-
+ let scopes: Vec<String> = db_client
+ .scopes
+ .split_whitespace()
+ .map(|s| s.to_string())
+ .collect();
+
let client = OAuthClient {
client_id: db_client.client_id.clone(),
client_secret_hash: db_client.client_secret_hash,
redirect_uris,
client_name: db_client.client_name,
scopes,
- grant_types: db_client.grant_types.split_whitespace().map(|s| s.to_string()).collect(),
- response_types: db_client.response_types.split_whitespace().map(|s| s.to_string()).collect(),
+ grant_types: db_client
+ .grant_types
+ .split_whitespace()
+ .map(|s| s.to_string())
+ .collect(),
+ response_types: db_client
+ .response_types
+ .split_whitespace()
+ .map(|s| s.to_string())
+ .collect(),
created_at: db_client.created_at.timestamp() as u64,
};
-
+
// Cache it
self.clients.insert(db_client.client_id, client.clone());
Ok(Some(client))
@@ -181,7 +197,7 @@ impl ClientManager {
return false;
}
};
-
+
let provided_hash = Self::hash_secret(client_secret);
// Use constant-time comparison to prevent timing attacks
self.constant_time_eq(&client.client_secret_hash, &provided_hash)
@@ -199,7 +215,9 @@ impl ClientManager {
if let Ok(Some(client)) = self.get_client_from_db(client_id) {
if let Some(scopes_str) = requested_scopes {
let requested: Vec<&str> = scopes_str.split_whitespace().collect();
- requested.iter().all(|scope| client.scopes.contains(&scope.to_string()))
+ requested
+ .iter()
+ .all(|scope| client.scopes.contains(&scope.to_string()))
} else {
true // No scopes requested is valid
}
@@ -224,7 +242,7 @@ impl ClientManager {
if a.len() != b.len() {
return false;
}
-
+
let mut result = 0u8;
for (byte_a, byte_b) in a.bytes().zip(b.bytes()) {
result |= byte_a ^ byte_b;
@@ -232,16 +250,24 @@ impl ClientManager {
result == 0
}
- pub fn generate_client_credentials(&mut self, client_name: String, redirect_uris: Vec<String>) -> Result<ClientCredentials> {
+ pub fn generate_client_credentials(
+ &mut self,
+ client_name: String,
+ redirect_uris: Vec<String>,
+ ) -> Result<ClientCredentials> {
let client_id = format!("client_{}", Uuid::new_v4().to_string().replace("-", ""));
let client_secret = Uuid::new_v4().to_string();
-
+
self.register_client(
client_id.clone(),
client_secret.clone(),
redirect_uris,
client_name,
- vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
+ vec![
+ "openid".to_string(),
+ "profile".to_string(),
+ "email".to_string(),
+ ],
)?;
Ok(ClientCredentials {
@@ -253,7 +279,7 @@ impl ClientManager {
pub fn list_clients(&self) -> Vec<&OAuthClient> {
self.clients.values().collect()
}
-
+
pub fn list_all_clients_from_db(&self) -> Result<Vec<DbOAuthClient>> {
// This would require a new database method - for now return empty
Ok(vec![])
@@ -270,13 +296,13 @@ pub fn parse_basic_auth(auth_header: &str) -> Option<(String, String)> {
let decoded = base64::engine::general_purpose::STANDARD
.decode(encoded)
.ok()?;
-
+
let credentials = String::from_utf8(decoded).ok()?;
let mut parts = credentials.splitn(2, ':');
-
+
let username = parts.next()?.to_string();
let password = parts.next()?.to_string();
-
+
Some((username, password))
}
@@ -288,7 +314,7 @@ mod disabled_tests {
#[test]
fn test_client_registration() {
let mut manager = ClientManager::new();
-
+
let result = manager.register_client(
"new_client".to_string(),
"secret123".to_string(),
@@ -296,7 +322,7 @@ mod disabled_tests {
"Test App".to_string(),
vec!["openid".to_string()],
);
-
+
assert!(result.is_ok());
let client = result.unwrap();
assert_eq!(client.client_id, "new_client");
@@ -306,7 +332,7 @@ mod disabled_tests {
#[test]
fn test_duplicate_client_id() {
let mut manager = ClientManager::new();
-
+
manager.register_client(
"duplicate".to_string(),
"secret1".to_string(),
@@ -314,7 +340,7 @@ mod disabled_tests {
"App 1".to_string(),
vec!["openid".to_string()],
).unwrap();
-
+
let result = manager.register_client(
"duplicate".to_string(),
"secret2".to_string(),
@@ -322,14 +348,14 @@ mod disabled_tests {
"App 2".to_string(),
vec!["openid".to_string()],
);
-
+
assert!(result.is_err());
}
#[test]
fn test_client_authentication() {
let mut manager = ClientManager::new();
-
+
manager.register_client(
"auth_client".to_string(),
"correct_secret".to_string(),
@@ -337,7 +363,7 @@ mod disabled_tests {
"Auth Test".to_string(),
vec!["openid".to_string()],
).unwrap();
-
+
assert!(manager.authenticate_client("auth_client", "correct_secret"));
assert!(!manager.authenticate_client("auth_client", "wrong_secret"));
assert!(!manager.authenticate_client("nonexistent", "any_secret"));
@@ -346,7 +372,7 @@ mod disabled_tests {
#[test]
fn test_redirect_uri_validation() {
let mut manager = ClientManager::new();
-
+
manager.register_client(
"uri_client".to_string(),
"secret".to_string(),
@@ -354,7 +380,7 @@ mod disabled_tests {
"URI Test".to_string(),
vec!["openid".to_string()],
).unwrap();
-
+
assert!(manager.is_redirect_uri_valid("uri_client", "https://app.com/callback"));
assert!(manager.is_redirect_uri_valid("uri_client", "http://localhost:3000/callback"));
assert!(!manager.is_redirect_uri_valid("uri_client", "https://evil.com/callback"));
@@ -364,7 +390,7 @@ mod disabled_tests {
#[test]
fn test_scope_validation() {
let mut manager = ClientManager::new();
-
+
manager.register_client(
"scope_client".to_string(),
"secret".to_string(),
@@ -372,7 +398,7 @@ mod disabled_tests {
"Scope Test".to_string(),
vec!["openid".to_string(), "profile".to_string()],
).unwrap();
-
+
assert!(manager.is_scope_valid("scope_client", &Some("openid".to_string())));
assert!(manager.is_scope_valid("scope_client", &Some("openid profile".to_string())));
assert!(!manager.is_scope_valid("scope_client", &Some("openid profile email".to_string())));
@@ -383,7 +409,7 @@ mod disabled_tests {
fn test_basic_auth_parsing() {
let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; // test_client:test_secret
let result = parse_basic_auth(auth_header);
-
+
assert!(result.is_some());
let (username, password) = result.unwrap();
assert_eq!(username, "test_client");
@@ -393,19 +419,19 @@ mod disabled_tests {
#[test]
fn test_generate_client_credentials() {
let mut manager = ClientManager::new();
-
+
let result = manager.generate_client_credentials(
"Generated App".to_string(),
vec!["https://generated.com/callback".to_string()],
);
-
+
assert!(result.is_ok());
let credentials = result.unwrap();
assert!(credentials.client_id.starts_with("client_"));
assert!(!credentials.client_secret.is_empty());
-
+
// Verify the client was actually registered
assert!(manager.authenticate_client(&credentials.client_id, &credentials.client_secret));
}
}
-*/ \ No newline at end of file
+*/