summaryrefslogtreecommitdiff
path: root/src/services
diff options
context:
space:
mode:
Diffstat (limited to 'src/services')
-rw-r--r--src/services/implementations.rs217
-rw-r--r--src/services/mod.rs49
2 files changed, 266 insertions, 0 deletions
diff --git a/src/services/implementations.rs b/src/services/implementations.rs
new file mode 100644
index 0000000..ff03165
--- /dev/null
+++ b/src/services/implementations.rs
@@ -0,0 +1,217 @@
+use super::*;
+use crate::clients::parse_basic_auth;
+use crate::config::Config;
+use crate::database::DbAuditLog;
+use crate::keys::KeyManager;
+use crate::oauth::types::Claims;
+use crate::repositories::{AuditRepository, ClientRepository, RateRepository};
+use anyhow::{Result, anyhow};
+use chrono::Utc;
+use jsonwebtoken::{Algorithm, Header, encode};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+use std::time::{SystemTime, UNIX_EPOCH};
+use uuid::Uuid;
+
+/// Default implementation of ClientAuthenticator
+pub struct DefaultClientAuthenticator {
+ client_repository: Arc<dyn ClientRepository>,
+}
+
+impl DefaultClientAuthenticator {
+ pub fn new(client_repository: Arc<dyn ClientRepository>) -> Self {
+ Self { client_repository }
+ }
+
+ fn authenticate_client(&self, client_id: &str, client_secret: &str) -> bool {
+ match self.client_repository.get_client(client_id) {
+ Ok(Some(client)) => {
+ // Use constant-time comparison to prevent timing attacks
+ use subtle::ConstantTimeEq;
+ let expected_hash = client.client_secret_hash.as_bytes();
+ let provided_hash = self.hash_client_secret(client_secret);
+ expected_hash.ct_eq(provided_hash.as_bytes()).into()
+ }
+ _ => false,
+ }
+ }
+
+ fn hash_client_secret(&self, secret: &str) -> String {
+ use sha2::{Digest, Sha256};
+ let mut hasher = Sha256::new();
+ hasher.update(secret.as_bytes());
+ format!("{:x}", hasher.finalize())
+ }
+}
+
+impl ClientAuthenticator for DefaultClientAuthenticator {
+ fn authenticate(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ) -> Result<(String, String), String> {
+ if let Some(auth_header) = auth_header {
+ // HTTP Basic Authentication (preferred method)
+ parse_basic_auth(auth_header).ok_or_else(|| "Invalid Authorization header".to_string())
+ } else {
+ // Form-based authentication (fallback)
+ let client_id = params
+ .get("client_id")
+ .ok_or_else(|| "Missing client_id".to_string())?;
+ let client_secret = params
+ .get("client_secret")
+ .ok_or_else(|| "Missing client_secret".to_string())?;
+
+ if !self.authenticate_client(client_id, client_secret) {
+ return Err("Invalid client credentials".to_string());
+ }
+
+ Ok((client_id.clone(), client_secret.clone()))
+ }
+ }
+}
+
+/// Default implementation of RateLimiter
+pub struct DefaultRateLimiter {
+ rate_repository: Arc<dyn RateRepository>,
+ config: Config,
+}
+
+impl DefaultRateLimiter {
+ pub fn new(rate_repository: Arc<dyn RateRepository>, config: Config) -> Self {
+ Self {
+ rate_repository,
+ config,
+ }
+ }
+}
+
+impl RateLimiter for DefaultRateLimiter {
+ fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> {
+ let count = self
+ .rate_repository
+ .increment_rate_limit(identifier, endpoint, 1)?;
+
+ if count > self.config.rate_limit_requests_per_minute as i32 {
+ return Err(anyhow!("Rate limit exceeded"));
+ }
+
+ Ok(())
+ }
+}
+
+/// Default implementation of AuditLogger
+pub struct DefaultAuditLogger {
+ audit_repository: Arc<dyn AuditRepository>,
+ config: Config,
+}
+
+impl DefaultAuditLogger {
+ pub fn new(audit_repository: Arc<dyn AuditRepository>, config: Config) -> Self {
+ Self {
+ audit_repository,
+ config,
+ }
+ }
+}
+
+impl AuditLogger for DefaultAuditLogger {
+ fn log_event(
+ &self,
+ event_type: &str,
+ client_id: Option<&str>,
+ user_id: Option<&str>,
+ ip_address: Option<&str>,
+ success: bool,
+ details: Option<&str>,
+ ) -> Result<()> {
+ if !self.config.enable_audit_logging {
+ return Ok(());
+ }
+
+ let log = DbAuditLog {
+ id: 0,
+ event_type: event_type.to_string(),
+ client_id: client_id.map(|s| s.to_string()),
+ user_id: user_id.map(|s| s.to_string()),
+ ip_address: ip_address.map(|s| s.to_string()),
+ user_agent: None, // Could be passed in from HTTP layer
+ details: details.map(|s| s.to_string()),
+ created_at: Utc::now(),
+ success,
+ };
+
+ self.audit_repository.create_audit_log(&log)?;
+ Ok(())
+ }
+}
+
+/// Default implementation of TokenGenerator
+pub struct DefaultTokenGenerator {
+ key_manager: Arc<Mutex<KeyManager>>,
+ config: Config,
+}
+
+impl DefaultTokenGenerator {
+ pub fn new(key_manager: Arc<Mutex<KeyManager>>, config: Config) -> Self {
+ Self {
+ key_manager,
+ config,
+ }
+ }
+}
+
+impl TokenGenerator for DefaultTokenGenerator {
+ fn generate_access_token(
+ &self,
+ user_id: &str,
+ client_id: &str,
+ scope: &Option<String>,
+ token_id: &str,
+ ) -> Result<String, String> {
+ let mut key_manager = self.key_manager.lock().unwrap();
+
+ // Check if we need to rotate keys
+ if key_manager.should_rotate() {
+ if let Err(_) = key_manager.rotate_keys() {
+ return Err("Key rotation failed".to_string());
+ }
+ }
+
+ let current_key = key_manager
+ .get_current_key()
+ .ok_or_else(|| "No signing key available".to_string())?;
+
+ let now = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ let claims = Claims {
+ sub: user_id.to_string(),
+ iss: self.config.issuer_url.clone(),
+ aud: client_id.to_string(),
+ exp: now + 3600,
+ iat: now,
+ scope: scope.clone(),
+ jti: Some(token_id.to_string()),
+ };
+
+ let mut header = Header::new(Algorithm::RS256);
+ header.kid = Some(current_key.kid.clone());
+
+ encode(&header, &claims, &current_key.encoding_key)
+ .map_err(|_| "Failed to generate token".to_string())
+ }
+
+ fn generate_refresh_token(
+ &self,
+ _client_id: &str,
+ _user_id: &str,
+ _scope: &Option<String>,
+ ) -> Result<String, String> {
+ // For now, return a simple UUID-based refresh token
+ // In production, this should be a proper JWT or encrypted token
+ Ok(Uuid::new_v4().to_string())
+ }
+}
diff --git a/src/services/mod.rs b/src/services/mod.rs
new file mode 100644
index 0000000..26d74e3
--- /dev/null
+++ b/src/services/mod.rs
@@ -0,0 +1,49 @@
+use anyhow::Result;
+use std::collections::HashMap;
+
+/// Service trait for client authentication
+pub trait ClientAuthenticator: Send + Sync {
+ fn authenticate(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ) -> Result<(String, String), String>; // Returns (client_id, client_secret)
+}
+
+/// Service trait for rate limiting
+pub trait RateLimiter: Send + Sync {
+ fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()>;
+}
+
+/// Service trait for audit logging
+pub trait AuditLogger: Send + Sync {
+ fn log_event(
+ &self,
+ event_type: &str,
+ client_id: Option<&str>,
+ user_id: Option<&str>,
+ ip_address: Option<&str>,
+ success: bool,
+ details: Option<&str>,
+ ) -> Result<()>;
+}
+
+/// Service trait for token generation
+pub trait TokenGenerator: Send + Sync {
+ fn generate_access_token(
+ &self,
+ user_id: &str,
+ client_id: &str,
+ scope: &Option<String>,
+ token_id: &str,
+ ) -> Result<String, String>;
+
+ fn generate_refresh_token(
+ &self,
+ client_id: &str,
+ user_id: &str,
+ scope: &Option<String>,
+ ) -> Result<String, String>;
+}
+
+pub mod implementations;