summaryrefslogtreecommitdiff
path: root/src/clients.rs
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-09 19:39:28 -0600
committermo khan <mo@mokhan.ca>2025-06-09 19:39:28 -0600
commit1b280d0b9be71daf9dffa0f4fa33559eda91946f (patch)
treecc301717deabe901ce82ce469fef6be84c1b2bed /src/clients.rs
parent50731865b1a22ab250e988aed2ea2bb8a18f9338 (diff)
Extract a registered client type
Diffstat (limited to 'src/clients.rs')
-rw-r--r--src/clients.rs318
1 files changed, 318 insertions, 0 deletions
diff --git a/src/clients.rs b/src/clients.rs
new file mode 100644
index 0000000..38d7450
--- /dev/null
+++ b/src/clients.rs
@@ -0,0 +1,318 @@
+use base64::Engine;
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
+use std::collections::HashMap;
+use uuid::Uuid;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct RegisteredClient {
+ pub client_id: String,
+ pub client_secret_hash: String,
+ pub redirect_uris: Vec<String>,
+ pub client_name: String,
+ pub scopes: Vec<String>,
+ pub grant_types: Vec<String>,
+ pub response_types: Vec<String>,
+ pub created_at: u64,
+}
+
+#[derive(Debug, Clone)]
+pub struct ClientCredentials {
+ pub client_id: String,
+ pub client_secret: String,
+}
+
+pub struct ClientManager {
+ clients: HashMap<String, RegisteredClient>,
+}
+
+impl ClientManager {
+ pub fn new() -> Self {
+ let mut manager = Self {
+ clients: HashMap::new(),
+ };
+
+ // Register a default test client for development
+ manager.register_client(
+ "test_client".to_string(),
+ "test_secret".to_string(),
+ vec!["http://localhost:3000/callback".to_string()],
+ "Test Client".to_string(),
+ vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
+ ).ok();
+
+ manager
+ }
+
+ pub fn register_client(
+ &mut self,
+ client_id: String,
+ client_secret: String,
+ redirect_uris: Vec<String>,
+ client_name: String,
+ scopes: Vec<String>,
+ ) -> Result<RegisteredClient, String> {
+ // Check if client_id already exists
+ if self.clients.contains_key(&client_id) {
+ return Err("Client ID already exists".to_string());
+ }
+
+ // Validate redirect URIs
+ for uri in &redirect_uris {
+ if !Self::is_valid_redirect_uri(uri) {
+ return Err(format!("Invalid redirect URI: {}", uri));
+ }
+ }
+
+ // Hash the client secret
+ let client_secret_hash = Self::hash_secret(&client_secret);
+
+ let client = RegisteredClient {
+ client_id: client_id.clone(),
+ client_secret_hash,
+ redirect_uris,
+ client_name,
+ scopes,
+ grant_types: vec!["authorization_code".to_string()],
+ response_types: vec!["code".to_string()],
+ created_at: std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs(),
+ };
+
+ self.clients.insert(client_id, client.clone());
+ Ok(client)
+ }
+
+ pub fn get_client(&self, client_id: &str) -> Option<&RegisteredClient> {
+ self.clients.get(client_id)
+ }
+
+ pub fn authenticate_client(&self, client_id: &str, client_secret: &str) -> bool {
+ if let Some(client) = self.get_client(client_id) {
+ let provided_hash = Self::hash_secret(client_secret);
+ // Use constant-time comparison to prevent timing attacks
+ self.constant_time_eq(&client.client_secret_hash, &provided_hash)
+ } else {
+ // Still perform hashing even for non-existent clients to prevent timing attacks
+ Self::hash_secret(client_secret);
+ false
+ }
+ }
+
+ pub fn is_redirect_uri_valid(&self, client_id: &str, redirect_uri: &str) -> bool {
+ if let Some(client) = self.get_client(client_id) {
+ client.redirect_uris.contains(&redirect_uri.to_string())
+ } else {
+ false
+ }
+ }
+
+ pub fn is_scope_valid(&self, client_id: &str, requested_scopes: &Option<String>) -> bool {
+ if let Some(client) = self.get_client(client_id) {
+ if let Some(scopes_str) = requested_scopes {
+ let requested: Vec<&str> = scopes_str.split_whitespace().collect();
+ requested.iter().all(|scope| client.scopes.contains(&scope.to_string()))
+ } else {
+ true // No scopes requested is valid
+ }
+ } else {
+ false
+ }
+ }
+
+ fn hash_secret(secret: &str) -> String {
+ let mut hasher = Sha256::new();
+ hasher.update(secret.as_bytes());
+ format!("{:x}", hasher.finalize())
+ }
+
+ fn is_valid_redirect_uri(uri: &str) -> bool {
+ // Basic validation - in production, this should be more comprehensive
+ uri.starts_with("http://") || uri.starts_with("https://")
+ }
+
+ // Constant-time string comparison to prevent timing attacks
+ fn constant_time_eq(&self, a: &str, b: &str) -> bool {
+ if a.len() != b.len() {
+ return false;
+ }
+
+ let mut result = 0u8;
+ for (byte_a, byte_b) in a.bytes().zip(b.bytes()) {
+ result |= byte_a ^ byte_b;
+ }
+ result == 0
+ }
+
+ pub fn generate_client_credentials(&mut self, client_name: String, redirect_uris: Vec<String>) -> Result<ClientCredentials, String> {
+ let client_id = format!("client_{}", Uuid::new_v4().to_string().replace("-", ""));
+ let client_secret = Uuid::new_v4().to_string();
+
+ self.register_client(
+ client_id.clone(),
+ client_secret.clone(),
+ redirect_uris,
+ client_name,
+ vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
+ )?;
+
+ Ok(ClientCredentials {
+ client_id,
+ client_secret,
+ })
+ }
+
+ pub fn list_clients(&self) -> Vec<&RegisteredClient> {
+ self.clients.values().collect()
+ }
+}
+
+// HTTP Basic Auth parsing helper
+pub fn parse_basic_auth(auth_header: &str) -> Option<(String, String)> {
+ if !auth_header.starts_with("Basic ") {
+ return None;
+ }
+
+ let encoded = &auth_header[6..];
+ let decoded = base64::engine::general_purpose::STANDARD
+ .decode(encoded)
+ .ok()?;
+
+ let credentials = String::from_utf8(decoded).ok()?;
+ let mut parts = credentials.splitn(2, ':');
+
+ let username = parts.next()?.to_string();
+ let password = parts.next()?.to_string();
+
+ Some((username, password))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_client_registration() {
+ let mut manager = ClientManager::new();
+
+ let result = manager.register_client(
+ "new_client".to_string(),
+ "secret123".to_string(),
+ vec!["https://example.com/callback".to_string()],
+ "Test App".to_string(),
+ vec!["openid".to_string()],
+ );
+
+ assert!(result.is_ok());
+ let client = result.unwrap();
+ assert_eq!(client.client_id, "new_client");
+ assert_eq!(client.client_name, "Test App");
+ }
+
+ #[test]
+ fn test_duplicate_client_id() {
+ let mut manager = ClientManager::new();
+
+ manager.register_client(
+ "duplicate".to_string(),
+ "secret1".to_string(),
+ vec!["https://example.com/callback".to_string()],
+ "App 1".to_string(),
+ vec!["openid".to_string()],
+ ).unwrap();
+
+ let result = manager.register_client(
+ "duplicate".to_string(),
+ "secret2".to_string(),
+ vec!["https://example.com/callback".to_string()],
+ "App 2".to_string(),
+ vec!["openid".to_string()],
+ );
+
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_client_authentication() {
+ let mut manager = ClientManager::new();
+
+ manager.register_client(
+ "auth_client".to_string(),
+ "correct_secret".to_string(),
+ vec!["https://example.com/callback".to_string()],
+ "Auth Test".to_string(),
+ vec!["openid".to_string()],
+ ).unwrap();
+
+ assert!(manager.authenticate_client("auth_client", "correct_secret"));
+ assert!(!manager.authenticate_client("auth_client", "wrong_secret"));
+ assert!(!manager.authenticate_client("nonexistent", "any_secret"));
+ }
+
+ #[test]
+ fn test_redirect_uri_validation() {
+ let mut manager = ClientManager::new();
+
+ manager.register_client(
+ "uri_client".to_string(),
+ "secret".to_string(),
+ vec!["https://app.com/callback".to_string(), "http://localhost:3000/callback".to_string()],
+ "URI Test".to_string(),
+ vec!["openid".to_string()],
+ ).unwrap();
+
+ assert!(manager.is_redirect_uri_valid("uri_client", "https://app.com/callback"));
+ assert!(manager.is_redirect_uri_valid("uri_client", "http://localhost:3000/callback"));
+ assert!(!manager.is_redirect_uri_valid("uri_client", "https://evil.com/callback"));
+ assert!(!manager.is_redirect_uri_valid("nonexistent", "https://app.com/callback"));
+ }
+
+ #[test]
+ fn test_scope_validation() {
+ let mut manager = ClientManager::new();
+
+ manager.register_client(
+ "scope_client".to_string(),
+ "secret".to_string(),
+ vec!["https://example.com/callback".to_string()],
+ "Scope Test".to_string(),
+ vec!["openid".to_string(), "profile".to_string()],
+ ).unwrap();
+
+ assert!(manager.is_scope_valid("scope_client", &Some("openid".to_string())));
+ assert!(manager.is_scope_valid("scope_client", &Some("openid profile".to_string())));
+ assert!(!manager.is_scope_valid("scope_client", &Some("openid profile email".to_string())));
+ assert!(manager.is_scope_valid("scope_client", &None));
+ }
+
+ #[test]
+ fn test_basic_auth_parsing() {
+ let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; // test_client:test_secret
+ let result = parse_basic_auth(auth_header);
+
+ assert!(result.is_some());
+ let (username, password) = result.unwrap();
+ assert_eq!(username, "test_client");
+ assert_eq!(password, "test_secret");
+ }
+
+ #[test]
+ fn test_generate_client_credentials() {
+ let mut manager = ClientManager::new();
+
+ let result = manager.generate_client_credentials(
+ "Generated App".to_string(),
+ vec!["https://generated.com/callback".to_string()],
+ );
+
+ assert!(result.is_ok());
+ let credentials = result.unwrap();
+ assert!(credentials.client_id.starts_with("client_"));
+ assert!(!credentials.client_secret.is_empty());
+
+ // Verify the client was actually registered
+ assert!(manager.authenticate_client(&credentials.client_id, &credentials.client_secret));
+ }
+} \ No newline at end of file