summaryrefslogtreecommitdiff
path: root/src/container.rs
blob: 3a4b13ed095ef0b6858802ceaa4404b552543858 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
use crate::config::Config;
use crate::database::Database;
use crate::keys::KeyManager;
use crate::repositories::*;
use crate::services::implementations::*;
use crate::services::*;
use anyhow::Result;
use std::sync::{Arc, Mutex};

/// Dependency injection container for all services and repositories.
///
/// Repositories and services are stored as reference-counted trait objects
/// (`Arc<dyn Trait>`), so a handle can be cloned out of the container without
/// naming the concrete implementation behind it. All fields are public.
pub struct ServiceContainer {
    // Repositories
    /// Handle to the `ClientRepository` implementation; `new` also passes a
    /// clone of it to the client authenticator.
    pub client_repository: Arc<dyn ClientRepository>,
    /// Handle to the `AuthCodeRepository` implementation; used by
    /// `cleanup_expired_data`.
    pub auth_code_repository: Arc<dyn AuthCodeRepository>,
    /// Handle to the `TokenRepository` implementation; used by
    /// `cleanup_expired_data`.
    pub token_repository: Arc<dyn TokenRepository>,
    /// Handle to the `AuditRepository` implementation; `new` also passes a
    /// clone of it to the audit logger.
    pub audit_repository: Arc<dyn AuditRepository>,
    /// Handle to the `RateRepository` implementation; `new` also passes a
    /// clone of it to the rate limiter.
    pub rate_repository: Arc<dyn RateRepository>,

    // Services
    /// Client authentication service, built in `new` on top of the client
    /// repository.
    pub client_authenticator: Arc<dyn ClientAuthenticator>,
    /// Rate limiting service, built in `new` from the rate repository and a
    /// copy of the configuration.
    pub rate_limiter: Arc<dyn RateLimiter>,
    /// Audit logging service, built in `new` from the audit repository and a
    /// copy of the configuration.
    pub audit_logger: Arc<dyn AuditLogger>,
    /// Token generation service, built in `new` from the key manager and a
    /// copy of the configuration.
    pub token_generator: Arc<dyn TokenGenerator>,

    // Core components
    /// Key manager behind a mutex; the same `Arc` is shared with the token
    /// generator and is locked by `get_jwks`.
    pub key_manager: Arc<Mutex<KeyManager>>,
    /// The configuration value this container was constructed with.
    pub config: Config,
}

impl ServiceContainer {
    pub fn new(config: Config, database: Arc<Mutex<Database>>) -> Result<Self> {
        // Create repositories
        let client_repository: Arc<dyn ClientRepository> =
            Arc::new(SqliteClientRepository::new(database.clone()));
        let auth_code_repository: Arc<dyn AuthCodeRepository> =
            Arc::new(SqliteAuthCodeRepository::new(database.clone()));
        let token_repository: Arc<dyn TokenRepository> =
            Arc::new(SqliteTokenRepository::new(database.clone()));
        let audit_repository: Arc<dyn AuditRepository> =
            Arc::new(SqliteAuditRepository::new(database.clone()));
        let rate_repository: Arc<dyn RateRepository> =
            Arc::new(SqliteRateRepository::new(database.clone()));

        // Create key manager
        let key_manager = Arc::new(Mutex::new(KeyManager::new(database.clone())?));

        // Create services
        let client_authenticator: Arc<dyn ClientAuthenticator> =
            Arc::new(DefaultClientAuthenticator::new(client_repository.clone()));
        let rate_limiter: Arc<dyn RateLimiter> = Arc::new(DefaultRateLimiter::new(
            rate_repository.clone(),
            config.clone(),
        ));
        let audit_logger: Arc<dyn AuditLogger> = Arc::new(DefaultAuditLogger::new(
            audit_repository.clone(),
            config.clone(),
        ));
        let token_generator: Arc<dyn TokenGenerator> = Arc::new(DefaultTokenGenerator::new(
            key_manager.clone(),
            config.clone(),
        ));

        Ok(Self {
            client_repository,
            auth_code_repository,
            token_repository,
            audit_repository,
            rate_repository,
            client_authenticator,
            rate_limiter,
            audit_logger,
            token_generator,
            key_manager,
            config,
        })
    }

    /// Get JWKS from the key manager
    pub fn get_jwks(&self) -> String {
        let key_manager = self.key_manager.lock().unwrap();
        match key_manager.get_jwks() {
            Ok(jwks) => serde_json::to_string(&jwks).unwrap_or_else(|_| "{}".to_string()),
            Err(_) => serde_json::json!({"keys": []}).to_string(),
        }
    }

    /// Cleanup expired data
    pub fn cleanup_expired_data(&self) -> Result<()> {
        // Cleanup expired authorization codes
        let _ = self.auth_code_repository.cleanup_expired_codes();

        // Cleanup expired tokens
        let _ = self.token_repository.cleanup_expired_tokens();

        // Cleanup old audit logs (keep for 30 days)
        let _ = self.audit_repository.cleanup_old_audit_logs(30);

        // Cleanup old rate limits
        let _ = self.rate_repository.cleanup_old_rate_limits();

        Ok(())
    }
}