summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock1
-rw-r--r--Cargo.toml1
-rw-r--r--REFACTORING_PLAN.md289
-rw-r--r--src/bin/debug.rs28
-rw-r--r--src/container.rs103
-rw-r--r--src/database.rs117
-rw-r--r--src/http/mod.rs86
-rw-r--r--src/lib.rs5
-rw-r--r--src/main.rs134
-rw-r--r--src/oauth/mod.rs4
-rw-r--r--src/oauth/pkce.rs18
-rw-r--r--src/oauth/server.rs8
-rw-r--r--src/oauth/service.rs566
-rw-r--r--src/oauth/types.rs17
-rw-r--r--src/repositories/mod.rs52
-rw-r--r--src/repositories/sqlite.rs164
-rw-r--r--src/services/implementations.rs217
-rw-r--r--src/services/mod.rs49
18 files changed, 1781 insertions, 78 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 21d1d82..f91c48a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -992,6 +992,7 @@ dependencies = [
"serde",
"serde_json",
"sha2",
+ "subtle",
"tokio",
"url",
"urlencoding",
diff --git a/Cargo.toml b/Cargo.toml
index d728ba0..039bb5f 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -30,3 +30,4 @@ rusqlite = { version = "0.32", features = ["bundled", "chrono"] }
chrono = { version = "0.4", features = ["serde"] }
tokio = { version = "1.0", features = ["full"] }
anyhow = "1.0"
+subtle = "2.4"
diff --git a/REFACTORING_PLAN.md b/REFACTORING_PLAN.md
new file mode 100644
index 0000000..ec24449
--- /dev/null
+++ b/REFACTORING_PLAN.md
@@ -0,0 +1,289 @@
+# OAuth2 STS SOLID Refactoring Plan
+
+## Overview
+This document outlines refactoring opportunities to better align with SOLID principles.
+
+## 1. Extract Grant Type Handlers (SRP + OCP)
+
+### Current Issue
+`OAuthServer::handle_token()` contains all grant type logic in one method.
+
+### Proposed Solution
+```rust
+trait GrantHandler {
+ fn can_handle(&self, grant_type: &str) -> bool;
+ fn handle_grant(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ip_address: Option<String>,
+ ) -> Result<String, String>;
+}
+
+struct AuthorizationCodeGrantHandler {
+ client_authenticator: Arc<dyn ClientAuthenticator>,
+ token_generator: Arc<dyn TokenGenerator>,
+ code_repository: Arc<dyn AuthCodeRepository>,
+ token_repository: Arc<dyn TokenRepository>,
+ audit_logger: Arc<dyn AuditLogger>,
+}
+
+struct RefreshTokenGrantHandler {
+ // Similar dependencies
+}
+
+struct GrantHandlerRegistry {
+ handlers: Vec<Arc<dyn GrantHandler>>,
+}
+```
+
+### Benefits
+- ✅ Easy to add new grant types (Client Credentials, Device Code, etc.)
+- ✅ Each handler has single responsibility
+- ✅ Testable in isolation
+
+## 2. Create Repository Abstractions (DIP + SRP)
+
+### Current Issue
+Monolithic `Database` struct violates SRP and makes testing difficult.
+
+### Proposed Solution
+```rust
+trait ClientRepository {
+ fn get_client(&self, client_id: &str) -> Result<Option<OAuthClient>>;
+ fn authenticate_client(&self, client_id: &str, secret: &str) -> Result<bool>;
+ fn is_redirect_uri_valid(&self, client_id: &str, uri: &str) -> bool;
+}
+
+trait TokenRepository {
+ fn create_access_token(&self, token: &DbAccessToken) -> Result<()>;
+ fn get_access_token(&self, token_hash: &str) -> Result<Option<DbAccessToken>>;
+ fn revoke_access_token(&self, token_hash: &str) -> Result<()>;
+}
+
+trait AuthCodeRepository {
+ fn create_auth_code(&self, code: &DbAuthCode) -> Result<()>;
+ fn get_auth_code(&self, code: &str) -> Result<Option<DbAuthCode>>;
+ fn mark_auth_code_used(&self, code: &str) -> Result<()>;
+}
+
+// SQLite implementations
+struct SqliteClientRepository {
+ database: Arc<Mutex<Database>>,
+}
+
+struct SqliteTokenRepository {
+ database: Arc<Mutex<Database>>,
+}
+```
+
+### Benefits
+- ✅ Easy to swap database backends (PostgreSQL, Redis, etc.)
+- ✅ Better testability with mock repositories
+- ✅ Clear separation of concerns
+
+## 3. Extract Authentication Strategy (OCP + SRP)
+
+### Current Issue
+Client authentication logic scattered throughout token handlers.
+
+### Proposed Solution
+```rust
+trait ClientAuthenticator {
+ fn authenticate(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ) -> Result<(String, String), String>; // Returns (client_id, client_secret)
+}
+
+struct BasicAuthenticator {
+ client_repository: Arc<dyn ClientRepository>,
+}
+
+struct PostAuthenticator {
+ client_repository: Arc<dyn ClientRepository>,
+}
+
+struct CompositeAuthenticator {
+ authenticators: Vec<Arc<dyn ClientAuthenticator>>,
+}
+```
+
+### Benefits
+- ✅ Easy to add new authentication methods (JWT, mTLS, etc.)
+- ✅ Clear separation of authentication concerns
+- ✅ Configurable authentication strategies
+
+## 4. Extract Token Generation (SRP + OCP)
+
+### Current Issue
+Token generation logic embedded in `OAuthServer`.
+
+### Proposed Solution
+```rust
+trait TokenGenerator {
+ fn generate_access_token(
+ &self,
+ user_id: &str,
+ client_id: &str,
+ scope: &Option<String>,
+ token_id: &str,
+ ) -> Result<String>;
+
+ fn generate_refresh_token(
+ &self,
+ client_id: &str,
+ user_id: &str,
+ scope: &Option<String>,
+ ) -> Result<String>;
+}
+
+struct JwtTokenGenerator {
+ key_manager: Arc<Mutex<KeyManager>>,
+ config: Config,
+}
+
+struct OpaqueTokenGenerator {
+ // For opaque token implementation
+}
+```
+
+### Benefits
+- ✅ Easy to switch token formats (JWT vs opaque)
+- ✅ Configurable token generation strategies
+- ✅ Better testability
+
+## 5. Extract Cross-Cutting Concerns
+
+### Rate Limiting
+```rust
+trait RateLimiter {
+ fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()>;
+}
+
+struct DatabaseRateLimiter {
+ rate_repository: Arc<dyn RateRepository>,
+ config: Config,
+}
+
+struct RedisRateLimiter {
+ redis_client: redis::Client,
+ config: Config,
+}
+```
+
+### Audit Logging
+```rust
+trait AuditLogger {
+ fn log_event(&self, event: AuditEvent) -> Result<()>;
+}
+
+struct DatabaseAuditLogger {
+ audit_repository: Arc<dyn AuditRepository>,
+}
+
+struct JsonFileAuditLogger {
+ file_path: PathBuf,
+}
+```
+
+## 6. Refactor HTTP Layer (SRP)
+
+### Current Issue
+`Server` mixes HTTP protocol with OAuth business logic.
+
+### Proposed Solution
+```rust
+struct HttpServer {
+ router: Router,
+ config: Config,
+}
+
+struct Router {
+ oauth_handler: Arc<OAuthHandler>,
+ static_handler: Arc<StaticHandler>,
+}
+
+struct OAuthHandler {
+ oauth_service: Arc<OAuthService>, // Renamed from OAuthServer
+}
+
+// Clean separation between HTTP concerns and OAuth business logic
+```
+
+## 7. Dependency Injection Container
+
+### Proposed Solution
+```rust
+struct ServiceContainer {
+ // Repositories
+ client_repository: Arc<dyn ClientRepository>,
+ token_repository: Arc<dyn TokenRepository>,
+ auth_code_repository: Arc<dyn AuthCodeRepository>,
+
+ // Services
+ client_authenticator: Arc<dyn ClientAuthenticator>,
+ token_generator: Arc<dyn TokenGenerator>,
+ rate_limiter: Arc<dyn RateLimiter>,
+ audit_logger: Arc<dyn AuditLogger>,
+
+ // Grant handlers
+ grant_registry: Arc<GrantHandlerRegistry>,
+}
+
+impl ServiceContainer {
+ fn new(config: &Config) -> Result<Self> {
+ // Wire up all dependencies
+ }
+}
+```
+
+## Implementation Strategy
+
+### Phase 1: Repository Extraction
+1. Create repository traits
+2. Move database methods to specific repositories
+3. Update `OAuthServer` to use repositories
+
+### Phase 2: Grant Handler Extraction
+1. Create `GrantHandler` trait
+2. Extract authorization code handler
+3. Extract refresh token handler
+4. Create registry
+
+### Phase 3: Cross-Cutting Concerns
+1. Extract rate limiting
+2. Extract audit logging
+3. Extract authentication
+
+### Phase 4: HTTP Layer Cleanup
+1. Separate HTTP protocol from business logic
+2. Create clean request/response handlers
+
+### Phase 5: Dependency Injection
+1. Create service container
+2. Wire up all dependencies
+3. Update main.rs to use container
+
+## Benefits of This Refactoring
+
+### Maintainability
+- ✅ Easier to understand and modify individual components
+- ✅ Clear separation of concerns
+- ✅ Reduced coupling between components
+
+### Extensibility
+- ✅ Easy to add new grant types
+- ✅ Easy to swap implementations (database, token format, etc.)
+- ✅ Configurable strategies for cross-cutting concerns
+
+### Testability
+- ✅ Each component can be tested in isolation
+- ✅ Easy to create mock implementations
+- ✅ Better unit test coverage
+
+### Production Readiness
+- ✅ Easy to scale different components independently
+- ✅ Better observability and monitoring capabilities
+- ✅ More flexible deployment options \ No newline at end of file
diff --git a/src/bin/debug.rs b/src/bin/debug.rs
index 6d80848..e05446b 100644
--- a/src/bin/debug.rs
+++ b/src/bin/debug.rs
@@ -1,21 +1,25 @@
fn main() {
let config = sts::Config::from_env();
println!("Config loaded: {}", config.bind_addr);
+
+ // Try the old-style server creation
let server = sts::http::Server::new(config.clone());
println!("Server result: {:?}", server.is_ok());
- if let Ok(server) = server {
- let oauth_server = &server.oauth_server;
- let jwks = oauth_server.get_jwks();
- println!("JWKS length: {}", jwks.len());
- println!(
- "JWKS: {}",
- if jwks.len() > 100 {
- &jwks[..100]
- } else {
- &jwks
- }
- );
+ if let Ok(_server) = server {
+ // Create OAuth server directly to test JWKS
+ if let Ok(oauth_server) = sts::OAuthServer::new(&config) {
+ let jwks = oauth_server.get_jwks();
+ println!("JWKS length: {}", jwks.len());
+ println!(
+ "JWKS: {}",
+ if jwks.len() > 100 {
+ &jwks[..100]
+ } else {
+ &jwks
+ }
+ );
+ }
}
let metadata = serde_json::json!({
diff --git a/src/container.rs b/src/container.rs
new file mode 100644
index 0000000..3a4b13e
--- /dev/null
+++ b/src/container.rs
@@ -0,0 +1,103 @@
+use crate::config::Config;
+use crate::database::Database;
+use crate::keys::KeyManager;
+use crate::repositories::*;
+use crate::services::implementations::*;
+use crate::services::*;
+use anyhow::Result;
+use std::sync::{Arc, Mutex};
+
+/// Dependency injection container for all services and repositories
+pub struct ServiceContainer {
+ // Repositories
+ pub client_repository: Arc<dyn ClientRepository>,
+ pub auth_code_repository: Arc<dyn AuthCodeRepository>,
+ pub token_repository: Arc<dyn TokenRepository>,
+ pub audit_repository: Arc<dyn AuditRepository>,
+ pub rate_repository: Arc<dyn RateRepository>,
+
+ // Services
+ pub client_authenticator: Arc<dyn ClientAuthenticator>,
+ pub rate_limiter: Arc<dyn RateLimiter>,
+ pub audit_logger: Arc<dyn AuditLogger>,
+ pub token_generator: Arc<dyn TokenGenerator>,
+
+ // Core components
+ pub key_manager: Arc<Mutex<KeyManager>>,
+ pub config: Config,
+}
+
+impl ServiceContainer {
+ pub fn new(config: Config, database: Arc<Mutex<Database>>) -> Result<Self> {
+ // Create repositories
+ let client_repository: Arc<dyn ClientRepository> =
+ Arc::new(SqliteClientRepository::new(database.clone()));
+ let auth_code_repository: Arc<dyn AuthCodeRepository> =
+ Arc::new(SqliteAuthCodeRepository::new(database.clone()));
+ let token_repository: Arc<dyn TokenRepository> =
+ Arc::new(SqliteTokenRepository::new(database.clone()));
+ let audit_repository: Arc<dyn AuditRepository> =
+ Arc::new(SqliteAuditRepository::new(database.clone()));
+ let rate_repository: Arc<dyn RateRepository> =
+ Arc::new(SqliteRateRepository::new(database.clone()));
+
+ // Create key manager
+ let key_manager = Arc::new(Mutex::new(KeyManager::new(database.clone())?));
+
+ // Create services
+ let client_authenticator: Arc<dyn ClientAuthenticator> =
+ Arc::new(DefaultClientAuthenticator::new(client_repository.clone()));
+ let rate_limiter: Arc<dyn RateLimiter> = Arc::new(DefaultRateLimiter::new(
+ rate_repository.clone(),
+ config.clone(),
+ ));
+ let audit_logger: Arc<dyn AuditLogger> = Arc::new(DefaultAuditLogger::new(
+ audit_repository.clone(),
+ config.clone(),
+ ));
+ let token_generator: Arc<dyn TokenGenerator> = Arc::new(DefaultTokenGenerator::new(
+ key_manager.clone(),
+ config.clone(),
+ ));
+
+ Ok(Self {
+ client_repository,
+ auth_code_repository,
+ token_repository,
+ audit_repository,
+ rate_repository,
+ client_authenticator,
+ rate_limiter,
+ audit_logger,
+ token_generator,
+ key_manager,
+ config,
+ })
+ }
+
+ /// Get JWKS from the key manager
+ pub fn get_jwks(&self) -> String {
+ let key_manager = self.key_manager.lock().unwrap();
+ match key_manager.get_jwks() {
+ Ok(jwks) => serde_json::to_string(&jwks).unwrap_or_else(|_| "{}".to_string()),
+ Err(_) => serde_json::json!({"keys": []}).to_string(),
+ }
+ }
+
+ /// Cleanup expired data
+ pub fn cleanup_expired_data(&self) -> Result<()> {
+ // Cleanup expired authorization codes
+ let _ = self.auth_code_repository.cleanup_expired_codes();
+
+ // Cleanup expired tokens
+ let _ = self.token_repository.cleanup_expired_tokens();
+
+ // Cleanup old audit logs (keep for 30 days)
+ let _ = self.audit_repository.cleanup_old_audit_logs(30);
+
+ // Cleanup old rate limits
+ let _ = self.rate_repository.cleanup_old_rate_limits();
+
+ Ok(())
+ }
+}
diff --git a/src/database.rs b/src/database.rs
index 2472d1a..5251dac 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -665,6 +665,123 @@ impl Database {
)?;
Ok(affected)
}
+
+ // Additional methods needed for repository patterns
+ pub fn update_oauth_client(&self, client: &DbOAuthClient) -> Result<()> {
+ self.conn.execute(
+ "UPDATE oauth_clients SET
+ client_secret_hash = ?2, client_name = ?3, redirect_uris = ?4,
+ scopes = ?5, grant_types = ?6, response_types = ?7,
+ updated_at = ?8, is_active = ?9
+ WHERE client_id = ?1",
+ params![
+ client.client_id,
+ client.client_secret_hash,
+ client.client_name,
+ client.redirect_uris,
+ client.scopes,
+ client.grant_types,
+ client.response_types,
+ client.updated_at.to_rfc3339(),
+ client.is_active
+ ],
+ )?;
+ Ok(())
+ }
+
+ pub fn delete_oauth_client(&self, client_id: &str) -> Result<()> {
+ self.conn.execute(
+ "DELETE FROM oauth_clients WHERE client_id = ?1",
+ [client_id],
+ )?;
+ Ok(())
+ }
+
+ pub fn list_oauth_clients(&self) -> Result<Vec<DbOAuthClient>> {
+ let mut stmt = self.conn.prepare(
+ "SELECT id, client_id, client_secret_hash, client_name, redirect_uris,
+ scopes, grant_types, response_types, created_at, updated_at, is_active
+ FROM oauth_clients ORDER BY created_at DESC",
+ )?;
+
+ let clients = stmt
+ .query_map([], |row| {
+ Ok(DbOAuthClient {
+ id: row.get(0)?,
+ client_id: row.get(1)?,
+ client_secret_hash: row.get(2)?,
+ client_name: row.get(3)?,
+ redirect_uris: row.get(4)?,
+ scopes: row.get(5)?,
+ grant_types: row.get(6)?,
+ response_types: row.get(7)?,
+ created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(8)?)
+ .map_err(|e| {
+ rusqlite::Error::FromSqlConversionFailure(
+ 8,
+ rusqlite::types::Type::Text,
+ Box::new(e),
+ )
+ })?
+ .with_timezone(&Utc),
+ updated_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(9)?)
+ .map_err(|e| {
+ rusqlite::Error::FromSqlConversionFailure(
+ 9,
+ rusqlite::types::Type::Text,
+ Box::new(e),
+ )
+ })?
+ .with_timezone(&Utc),
+ is_active: row.get(10)?,
+ })
+ })?
+ .collect::<Result<Vec<_>, _>>()?;
+
+ Ok(clients)
+ }
+
+ pub fn get_audit_logs(&self, limit: i32) -> Result<Vec<DbAuditLog>> {
+ let mut stmt = self.conn.prepare(
+ "SELECT id, event_type, client_id, user_id, ip_address, user_agent, details, created_at, success
+ FROM audit_logs ORDER BY created_at DESC LIMIT ?1"
+ )?;
+
+ let logs = stmt
+ .query_map([limit], |row| {
+ Ok(DbAuditLog {
+ id: row.get(0)?,
+ event_type: row.get(1)?,
+ client_id: row.get(2)?,
+ user_id: row.get(3)?,
+ ip_address: row.get(4)?,
+ user_agent: row.get(5)?,
+ details: row.get(6)?,
+ created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?)
+ .map_err(|e| {
+ rusqlite::Error::FromSqlConversionFailure(
+ 7,
+ rusqlite::types::Type::Text,
+ Box::new(e),
+ )
+ })?
+ .with_timezone(&Utc),
+ success: row.get(8)?,
+ })
+ })?
+ .collect::<Result<Vec<_>, _>>()?;
+
+ Ok(logs)
+ }
+
+ pub fn cleanup_old_rate_limits(&self) -> Result<()> {
+ let cutoff = Utc::now() - chrono::Duration::hours(24); // Clean up rate limits older than 24 hours
+ self.conn.execute(
+ "DELETE FROM rate_limits WHERE created_at < ?1",
+ [cutoff.to_rfc3339()],
+ )?;
+ Ok(())
+ }
}
#[cfg(test)]
diff --git a/src/http/mod.rs b/src/http/mod.rs
index 1bc7951..778a3de 100644
--- a/src/http/mod.rs
+++ b/src/http/mod.rs
@@ -1,21 +1,38 @@
use crate::config::Config;
-use crate::oauth::OAuthServer;
+use crate::container::ServiceContainer;
+use crate::oauth::{OAuthServer, OAuthService};
use std::collections::HashMap;
use std::fs;
use std::io::prelude::*;
use std::net::{TcpListener, TcpStream};
+use std::sync::Arc;
use url::Url;
pub struct Server {
config: Config,
- pub oauth_server: OAuthServer,
+ oauth_server: Option<OAuthServer>,
+ oauth_service: Option<Arc<ServiceContainer>>,
}
impl Server {
pub fn new(config: Config) -> Result<Server, Box<dyn std::error::Error>> {
Ok(Server {
- oauth_server: OAuthServer::new(&config)
- .map_err(|e| format!("Failed to create OAuth server: {}", e))?,
+ oauth_server: Some(
+ OAuthServer::new(&config)
+ .map_err(|e| format!("Failed to create OAuth server: {}", e))?,
+ ),
+ oauth_service: None,
+ config,
+ })
+ }
+
+ pub fn new_with_container(
+ config: Config,
+ container: Arc<ServiceContainer>,
+ ) -> Result<Server, Box<dyn std::error::Error>> {
+ Ok(Server {
+ oauth_server: None,
+ oauth_service: Some(container),
config,
})
}
@@ -194,7 +211,13 @@ impl Server {
}
fn handle_jwks(&self, stream: &mut TcpStream) {
- let jwks = self.oauth_server.get_jwks();
+ let jwks = if let Some(ref oauth_server) = self.oauth_server {
+ oauth_server.get_jwks()
+ } else if let Some(ref container) = self.oauth_service {
+ container.get_jwks()
+ } else {
+ "{\"keys\":[]}".to_string()
+ };
self.send_json_response(stream, 200, "OK", &jwks);
}
@@ -204,7 +227,16 @@ impl Server {
params: &HashMap<String, String>,
ip_address: Option<String>,
) {
- match self.oauth_server.handle_authorize(params, ip_address) {
+ let result = if let Some(ref oauth_server) = self.oauth_server {
+ oauth_server.handle_authorize(params, ip_address)
+ } else if let Some(ref container) = self.oauth_service {
+ let oauth_service = OAuthService::new(container.clone());
+ oauth_service.handle_authorize(params, ip_address)
+ } else {
+ Err("{\"error\": \"server_error\", \"error_description\": \"No OAuth service available\"}".to_string())
+ };
+
+ match result {
Ok(redirect_url) => {
let security_headers = self.get_security_headers();
let response = format!(
@@ -227,10 +259,16 @@ impl Server {
// Extract Authorization header from request
let auth_header = self.extract_auth_header(request);
- match self
- .oauth_server
- .handle_token(&form_params, auth_header.as_deref(), ip_address)
- {
+ let result = if let Some(ref oauth_server) = self.oauth_server {
+ oauth_server.handle_token(&form_params, auth_header.as_deref(), ip_address)
+ } else if let Some(ref container) = self.oauth_service {
+ let oauth_service = OAuthService::new(container.clone());
+ oauth_service.handle_token(&form_params, auth_header.as_deref(), ip_address)
+ } else {
+ Err("{\"error\": \"server_error\", \"error_description\": \"No OAuth service available\"}".to_string())
+ };
+
+ match result {
Ok(token_response) => {
self.send_json_response(stream, 200, "OK", &token_response);
}
@@ -245,10 +283,16 @@ impl Server {
let form_params = self.parse_form_data(&body);
let auth_header = self.extract_auth_header(request);
- match self
- .oauth_server
- .handle_token_introspection(&form_params, auth_header.as_deref())
- {
+ let result = if let Some(ref oauth_server) = self.oauth_server {
+ oauth_server.handle_token_introspection(&form_params, auth_header.as_deref())
+ } else if let Some(ref container) = self.oauth_service {
+ let oauth_service = OAuthService::new(container.clone());
+ oauth_service.handle_token_introspection(&form_params, auth_header.as_deref())
+ } else {
+ Err("{\"error\": \"server_error\", \"error_description\": \"No OAuth service available\"}".to_string())
+ };
+
+ match result {
Ok(introspection_response) => {
self.send_json_response(stream, 200, "OK", &introspection_response);
}
@@ -263,10 +307,16 @@ impl Server {
let form_params = self.parse_form_data(&body);
let auth_header = self.extract_auth_header(request);
- match self
- .oauth_server
- .handle_token_revocation(&form_params, auth_header.as_deref())
- {
+ let result = if let Some(ref oauth_server) = self.oauth_server {
+ oauth_server.handle_token_revocation(&form_params, auth_header.as_deref())
+ } else if let Some(ref container) = self.oauth_service {
+ let oauth_service = OAuthService::new(container.clone());
+ oauth_service.handle_token_revocation(&form_params, auth_header.as_deref())
+ } else {
+ Err("{\"error\": \"server_error\", \"error_description\": \"No OAuth service available\"}".to_string())
+ };
+
+ match result {
Ok(_) => {
self.send_empty_response(stream, 200, "OK");
}
diff --git a/src/lib.rs b/src/lib.rs
index eef2cbf..2b25e8a 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,14 +1,17 @@
pub mod clients;
pub mod config;
+pub mod container;
pub mod database;
pub mod http;
pub mod keys;
pub mod migrations;
pub mod oauth;
+pub mod repositories;
+pub mod services;
pub use clients::ClientManager;
pub use config::Config;
pub use database::Database;
pub use http::Server;
pub use migrations::MigrationRunner;
-pub use oauth::OAuthServer;
+pub use oauth::{OAuthServer, OAuthService};
diff --git a/src/main.rs b/src/main.rs
index ac47a5e..4873a1d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,22 +1,36 @@
+use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
-use sts::Config;
+use sts::container::ServiceContainer;
use sts::http::Server;
+use sts::{Config, Database};
fn main() {
let config = Config::from_env();
- let server = Server::new(config.clone()).expect("Failed to create server");
+
+ // Initialize database
+ let database = Database::new(&config.database_path).expect("Failed to initialize database");
+ let database = Arc::new(Mutex::new(database));
+
+ // Initialize service container with dependency injection
+ let container = ServiceContainer::new(config.clone(), database.clone())
+ .expect("Failed to create service container");
+ let container = Arc::new(container);
+
+ let server = Server::new_with_container(config.clone(), container.clone())
+ .expect("Failed to create server");
// Start cleanup task in background
+ let cleanup_container = container.clone();
let cleanup_config = config.clone();
thread::spawn(move || {
loop {
thread::sleep(Duration::from_secs(
cleanup_config.cleanup_interval_hours as u64 * 3600,
));
- // Note: In the current implementation, we don't have direct access to the OAuth server
- // from here to call cleanup_expired_data(). In a production implementation,
- // you'd want to structure this differently or use a background job queue.
+ if let Err(e) = cleanup_container.cleanup_expired_data() {
+ eprintln!("Cleanup task failed: {}", e);
+ }
}
});
@@ -139,11 +153,16 @@ mod tests {
// Step 1: Authorization request
let mut auth_params = HashMap::new();
auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
auth_params.insert("response_type".to_string(), "code".to_string());
auth_params.insert("state".to_string(), "test_state".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
+ let auth_result = oauth_server
+ .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
+ .expect("Authorization failed");
// Extract the authorization code from redirect URL
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
@@ -160,11 +179,13 @@ mod tests {
token_params.insert("client_id".to_string(), "test_client".to_string());
token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed");
+ let token_result = oauth_server
+ .handle_token(&token_params, None, Some("127.0.0.1".to_string()))
+ .expect("Token request failed");
// Parse token response
- let token_response: serde_json::Value = serde_json::from_str(&token_result)
- .expect("Invalid token response JSON");
+ let token_response: serde_json::Value =
+ serde_json::from_str(&token_result).expect("Invalid token response JSON");
assert_eq!(token_response["token_type"], "Bearer");
assert_eq!(token_response["expires_in"], 3600);
@@ -173,7 +194,8 @@ mod tests {
let access_token = token_response["access_token"].as_str().unwrap();
// Step 3: Verify the JWT token has RSA signature and key ID
- let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
+ let header =
+ jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
assert!(header.kid.is_some());
assert!(!header.kid.as_ref().unwrap().is_empty());
@@ -187,10 +209,15 @@ mod tests {
// Generate a token
let mut auth_params = HashMap::new();
auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
auth_params.insert("response_type".to_string(), "code".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
+ let auth_result = oauth_server
+ .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
+ .expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -204,9 +231,11 @@ mod tests {
token_params.insert("client_id".to_string(), "test_client".to_string());
token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed");
- let token_response: serde_json::Value = serde_json::from_str(&token_result)
- .expect("Invalid token response JSON");
+ let token_result = oauth_server
+ .handle_token(&token_params, None, Some("127.0.0.1".to_string()))
+ .expect("Token request failed");
+ let token_response: serde_json::Value =
+ serde_json::from_str(&token_result).expect("Invalid token response JSON");
let access_token = token_response["access_token"].as_str().unwrap();
// Get the JWKS
@@ -214,7 +243,8 @@ mod tests {
let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON");
// Decode the token header to get the key ID
- let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
+ let header =
+ jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
let kid = header.kid.as_ref().expect("No key ID in token");
// Find the matching key in JWKS
@@ -237,11 +267,16 @@ mod tests {
// Generate a token through the full flow
let mut auth_params = HashMap::new();
auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
auth_params.insert("response_type".to_string(), "code".to_string());
auth_params.insert("scope".to_string(), "openid profile".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
+ let auth_result = oauth_server
+ .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
+ .expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -255,16 +290,18 @@ mod tests {
token_params.insert("client_id".to_string(), "test_client".to_string());
token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed");
- let token_response: serde_json::Value = serde_json::from_str(&token_result)
- .expect("Invalid token response JSON");
+ let token_result = oauth_server
+ .handle_token(&token_params, None, Some("127.0.0.1".to_string()))
+ .expect("Token request failed");
+ let token_response: serde_json::Value =
+ serde_json::from_str(&token_result).expect("Invalid token response JSON");
let access_token = token_response["access_token"].as_str().unwrap();
// Decode the token without verification to check claims
let _token_data = jsonwebtoken::decode::<serde_json::Value>(
access_token,
&jsonwebtoken::DecodingKey::from_secret(b"dummy"), // We're not validating, just parsing
- &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256)
+ &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256),
);
// Since we can't validate with a dummy key, we'll just verify the structure
@@ -275,7 +312,8 @@ mod tests {
let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(parts[1])
.expect("Failed to decode payload");
- let claims: serde_json::Value = serde_json::from_slice(&payload).expect("Invalid claims JSON");
+ let claims: serde_json::Value =
+ serde_json::from_slice(&payload).expect("Invalid claims JSON");
assert!(claims["sub"].is_string());
assert!(claims["iss"].is_string());
@@ -293,7 +331,10 @@ mod tests {
let mut params = HashMap::new();
params.insert("client_id".to_string(), "invalid_client".to_string());
- params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
params.insert("response_type".to_string(), "code".to_string());
let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
@@ -308,7 +349,10 @@ mod tests {
let mut params = HashMap::new();
params.insert("client_id".to_string(), "test_client".to_string());
- params.insert("redirect_uri".to_string(), "https://evil.com/callback".to_string());
+ params.insert(
+ "redirect_uri".to_string(),
+ "https://evil.com/callback".to_string(),
+ );
params.insert("response_type".to_string(), "code".to_string());
let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
@@ -323,7 +367,10 @@ mod tests {
let mut params = HashMap::new();
params.insert("client_id".to_string(), "test_client".to_string());
- params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
params.insert("response_type".to_string(), "code".to_string());
params.insert("scope".to_string(), "invalid_scope".to_string());
@@ -340,10 +387,15 @@ mod tests {
// First get an authorization code
let mut auth_params = HashMap::new();
auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
auth_params.insert("response_type".to_string(), "code".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
+ let auth_result = oauth_server
+ .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
+ .expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -371,10 +423,15 @@ mod tests {
// First get an authorization code
let mut auth_params = HashMap::new();
auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
auth_params.insert("response_type".to_string(), "code".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
+ let auth_result = oauth_server
+ .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
+ .expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -402,10 +459,15 @@ mod tests {
// First get an authorization code
let mut auth_params = HashMap::new();
auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
auth_params.insert("response_type".to_string(), "code".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
+ let auth_result = oauth_server
+ .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
+ .expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -422,7 +484,11 @@ mod tests {
// test_client:test_secret encoded in base64
let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ=";
- let result = oauth_server.handle_token(&token_params, Some(auth_header), Some("127.0.0.1".to_string()));
+ let result = oauth_server.handle_token(
+ &token_params,
+ Some(auth_header),
+ Some("127.0.0.1".to_string()),
+ );
assert!(result.is_ok());
}
}
diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs
index 4b18bb3..b2d46fa 100644
--- a/src/oauth/mod.rs
+++ b/src/oauth/mod.rs
@@ -1,9 +1,11 @@
pub mod pkce;
pub mod server;
+pub mod service;
pub mod types;
pub use pkce::{
- generate_code_challenge, generate_code_verifier, verify_code_challenge, CodeChallengeMethod,
+ CodeChallengeMethod, generate_code_challenge, generate_code_verifier, verify_code_challenge,
};
pub use server::OAuthServer;
+pub use service::OAuthService;
pub use types::{AuthCode, Claims, ErrorResponse, TokenResponse};
diff --git a/src/oauth/pkce.rs b/src/oauth/pkce.rs
index 406d364..0dfc1f8 100644
--- a/src/oauth/pkce.rs
+++ b/src/oauth/pkce.rs
@@ -1,5 +1,5 @@
-use anyhow::{anyhow, Result};
-use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
+use anyhow::{Result, anyhow};
+use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use sha2::{Digest, Sha256};
#[derive(Debug, Clone, PartialEq)]
@@ -124,12 +124,14 @@ mod tests {
assert!(verify_code_challenge("short", "challenge", &CodeChallengeMethod::Plain).is_err());
// Invalid characters
- assert!(verify_code_challenge(
- "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!",
- "challenge",
- &CodeChallengeMethod::Plain
- )
- .is_err());
+ assert!(
+ verify_code_challenge(
+ "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!",
+ "challenge",
+ &CodeChallengeMethod::Plain
+ )
+ .is_err()
+ );
}
#[test]
diff --git a/src/oauth/server.rs b/src/oauth/server.rs
index 7fd8b9c..37c3cbc 100644
--- a/src/oauth/server.rs
+++ b/src/oauth/server.rs
@@ -1,12 +1,12 @@
-use crate::clients::{parse_basic_auth, ClientManager};
+use crate::clients::{ClientManager, parse_basic_auth};
use crate::config::Config;
use crate::database::{Database, DbAccessToken, DbAuditLog, DbAuthCode};
use crate::keys::KeyManager;
-use crate::oauth::pkce::{verify_code_challenge, CodeChallengeMethod};
+use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge};
use crate::oauth::types::{Claims, ErrorResponse, TokenIntrospectionResponse, TokenResponse};
-use anyhow::{anyhow, Result};
+use anyhow::{Result, anyhow};
use chrono::{Duration, Utc};
-use jsonwebtoken::{encode, Algorithm, Header};
+use jsonwebtoken::{Algorithm, Header, encode};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
diff --git a/src/oauth/service.rs b/src/oauth/service.rs
new file mode 100644
index 0000000..1b4eb49
--- /dev/null
+++ b/src/oauth/service.rs
@@ -0,0 +1,566 @@
+use crate::container::ServiceContainer;
+use crate::database::{DbAccessToken, DbAuthCode};
+use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge};
+use crate::oauth::types::{ErrorResponse, TokenIntrospectionResponse, TokenResponse};
+use anyhow::Result;
+use chrono::{Duration, Utc};
+use sha2::{Digest, Sha256};
+use std::collections::HashMap;
+use std::sync::Arc;
+use url::Url;
+use uuid::Uuid;
+
+/// Refactored OAuth service using dependency injection
+pub struct OAuthService {
+ container: Arc<ServiceContainer>,
+}
+
+impl OAuthService {
+ pub fn new(container: Arc<ServiceContainer>) -> Self {
+ Self { container }
+ }
+
+ pub fn get_jwks(&self) -> String {
+ self.container.get_jwks()
+ }
+
+ pub fn handle_authorize(
+ &self,
+ params: &HashMap<String, String>,
+ ip_address: Option<String>,
+ ) -> Result<String, String> {
+ let client_id = params
+ .get("client_id")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
+
+ let redirect_uri = params
+ .get("redirect_uri")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?;
+
+ let response_type = params
+ .get("response_type")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?;
+
+ // Rate limiting check
+ if let Err(e) = self
+ .container
+ .rate_limiter
+ .check_rate_limit(&format!("client:{}", client_id), "/authorize")
+ {
+ let _ = self.container.audit_logger.log_event(
+ "authorize_rate_limited",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(&e.to_string()),
+ );
+ return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded"));
+ }
+
+ // Validate client exists
+ let client = match self.container.client_repository.get_client(client_id) {
+ Ok(Some(client)) => client,
+ Ok(None) => {
+ let _ = self.container.audit_logger.log_event(
+ "authorize_invalid_client",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ None,
+ );
+ return Err(self.error_response("invalid_client", "Invalid client_id"));
+ }
+ Err(_) => {
+ return Err(self.error_response("server_error", "Internal server error"));
+ }
+ };
+
+ // Validate redirect URI
+ let redirect_uris: Vec<String> =
+ serde_json::from_str(&client.redirect_uris).unwrap_or_else(|_| vec![]);
+
+ if !redirect_uris.contains(redirect_uri) {
+ let _ = self.container.audit_logger.log_event(
+ "authorize_invalid_redirect_uri",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(redirect_uri),
+ );
+ return Err(self.error_response("invalid_request", "Invalid redirect_uri"));
+ }
+
+ // Validate requested scopes
+ let scope = params.get("scope").cloned();
+ if let Some(ref scope_str) = scope {
+ let client_scopes: Vec<&str> = client.scopes.split_whitespace().collect();
+ let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect();
+
+ for requested_scope in &requested_scopes {
+ if !client_scopes.contains(requested_scope) {
+ let _ = self.container.audit_logger.log_event(
+ "authorize_invalid_scope",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ scope.as_deref(),
+ );
+ return Err(self.error_response("invalid_scope", "Invalid scope"));
+ }
+ }
+ }
+
+ if response_type != "code" {
+ let _ = self.container.audit_logger.log_event(
+ "authorize_unsupported_response_type",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(response_type),
+ );
+ return Err(self.error_response(
+ "unsupported_response_type",
+ "Only code response type supported",
+ ));
+ }
+
+ // PKCE validation (RFC 7636)
+ let code_challenge = params.get("code_challenge");
+ let code_challenge_method = params
+ .get("code_challenge_method")
+ .map(|method| CodeChallengeMethod::from_str(method))
+ .transpose()
+ .map_err(|_| self.error_response("invalid_request", "Invalid code_challenge_method"))?;
+
+ // For public clients, PKCE is required
+ if client.client_id.starts_with("public_") && code_challenge.is_none() {
+ let _ = self.container.audit_logger.log_event(
+ "authorize_missing_pkce",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ None,
+ );
+ return Err(self.error_response("invalid_request", "PKCE required for public clients"));
+ }
+
+ let code = Uuid::new_v4().to_string();
+ let expires_at = Utc::now() + Duration::minutes(10); // 10 minute expiration
+
+ let db_auth_code = DbAuthCode {
+ id: 0, // Will be set by database
+ code: code.clone(),
+ client_id: client_id.clone(),
+ user_id: "test_user".to_string(), // In a real implementation, this would come from authentication
+ redirect_uri: redirect_uri.clone(),
+ scope: scope.clone(),
+ expires_at,
+ created_at: Utc::now(),
+ is_used: false,
+ code_challenge: code_challenge.cloned(),
+ code_challenge_method: code_challenge_method
+ .as_ref()
+ .map(|m| m.as_str().to_string()),
+ };
+
+ // Save to database
+ if let Err(_) = self
+ .container
+ .auth_code_repository
+ .create_auth_code(&db_auth_code)
+ {
+ return Err(self.error_response("server_error", "Failed to create authorization code"));
+ }
+
+ let mut redirect_url = Url::parse(redirect_uri)
+ .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?;
+
+ redirect_url.query_pairs_mut().append_pair("code", &code);
+
+ if let Some(state) = params.get("state") {
+ redirect_url.query_pairs_mut().append_pair("state", state);
+ }
+
+ let _ = self.container.audit_logger.log_event(
+ "authorize_success",
+ Some(client_id),
+ Some("test_user"),
+ ip_address.as_deref(),
+ true,
+ None,
+ );
+
+ Ok(redirect_url.to_string())
+ }
+
+ pub fn handle_token(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ip_address: Option<String>,
+ ) -> Result<String, String> {
+ let grant_type = params
+ .get("grant_type")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?;
+
+ match grant_type.as_str() {
+ "authorization_code" => {
+ self.handle_authorization_code_grant(params, auth_header, ip_address)
+ }
+ "refresh_token" => self.handle_refresh_token_grant(params, auth_header, ip_address),
+ _ => {
+ let _ = self.container.audit_logger.log_event(
+ "token_unsupported_grant_type",
+ None,
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(grant_type),
+ );
+ Err(self.error_response("unsupported_grant_type", "Unsupported grant type"))
+ }
+ }
+ }
+
+ fn handle_authorization_code_grant(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ip_address: Option<String>,
+ ) -> Result<String, String> {
+ let code = params
+ .get("code")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?;
+
+ // Client authentication using injected service
+ let (client_id, _client_secret) = self
+ .container
+ .client_authenticator
+ .authenticate(params, auth_header)
+ .map_err(|e| self.error_response("invalid_client", &e))?;
+
+ // Rate limiting check
+ if let Err(e) = self
+ .container
+ .rate_limiter
+ .check_rate_limit(&format!("client:{}", client_id), "/token")
+ {
+ let _ = self.container.audit_logger.log_event(
+ "token_rate_limited",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(&e.to_string()),
+ );
+ return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded"));
+ }
+
+ // Get and validate authorization code
+ let auth_code = match self.container.auth_code_repository.get_auth_code(code) {
+ Ok(Some(auth_code)) => auth_code,
+ Ok(None) => {
+ let _ = self.container.audit_logger.log_event(
+ "token_invalid_code",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(code),
+ );
+ return Err(
+ self.error_response("invalid_grant", "Invalid or expired authorization code")
+ );
+ }
+ Err(_) => {
+ return Err(self.error_response("server_error", "Internal server error"));
+ }
+ };
+
+ // Validate code hasn't been used and hasn't expired
+ if auth_code.is_used {
+ let _ = self.container.audit_logger.log_event(
+ "token_code_reuse",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ Some(code),
+ );
+ return Err(self.error_response("invalid_grant", "Authorization code already used"));
+ }
+
+ if Utc::now() > auth_code.expires_at {
+ let _ = self.container.audit_logger.log_event(
+ "token_code_expired",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ Some(code),
+ );
+ return Err(self.error_response("invalid_grant", "Authorization code expired"));
+ }
+
+ if auth_code.client_id != client_id {
+ let _ = self.container.audit_logger.log_event(
+ "token_client_mismatch",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ None,
+ );
+ return Err(self.error_response("invalid_grant", "Client ID mismatch"));
+ }
+
+ // PKCE validation if code challenge was provided
+ if let Some(code_challenge) = &auth_code.code_challenge {
+ let code_verifier = params.get("code_verifier").ok_or_else(|| {
+ self.error_response("invalid_request", "Missing code_verifier for PKCE")
+ })?;
+
+ let challenge_method = auth_code
+ .code_challenge_method
+ .as_ref()
+ .and_then(|method| CodeChallengeMethod::from_str(method).ok())
+ .unwrap_or(CodeChallengeMethod::Plain);
+
+ if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method)
+ {
+ let _ = self.container.audit_logger.log_event(
+ "token_pkce_verification_failed",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ None,
+ );
+ return Err(self.error_response("invalid_grant", "PKCE verification failed"));
+ }
+ }
+
+ // Mark code as used
+ if let Err(_) = self
+ .container
+ .auth_code_repository
+ .mark_auth_code_used(code)
+ {
+ return Err(self.error_response("server_error", "Failed to mark code as used"));
+ }
+
+ // Generate tokens using injected service
+ let token_id = Uuid::new_v4().to_string();
+ let access_token = self.container.token_generator.generate_access_token(
+ &auth_code.user_id,
+ &client_id,
+ &auth_code.scope,
+ &token_id,
+ )?;
+ let refresh_token = self.container.token_generator.generate_refresh_token(
+ &client_id,
+ &auth_code.user_id,
+ &auth_code.scope,
+ )?;
+
+ // Store token in database for revocation/introspection
+ let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes()));
+ let db_access_token = DbAccessToken {
+ id: 0,
+ token_id: token_id.clone(),
+ client_id: client_id.clone(),
+ user_id: auth_code.user_id.clone(),
+ scope: auth_code.scope.clone(),
+ expires_at: Utc::now() + Duration::hours(1),
+ created_at: Utc::now(),
+ is_revoked: false,
+ token_hash,
+ };
+
+ if let Err(_) = self
+ .container
+ .token_repository
+ .create_access_token(&db_access_token)
+ {
+ return Err(self.error_response("server_error", "Failed to store access token"));
+ }
+
+ let token_response = TokenResponse {
+ access_token,
+ token_type: "Bearer".to_string(),
+ expires_in: 3600,
+ refresh_token: Some(refresh_token),
+ scope: auth_code.scope,
+ };
+
+ let _ = self.container.audit_logger.log_event(
+ "token_success",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ true,
+ None,
+ );
+
+ serde_json::to_string(&token_response)
+ .map_err(|_| self.error_response("server_error", "Failed to serialize token response"))
+ }
+
+ fn handle_refresh_token_grant(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ip_address: Option<String>,
+ ) -> Result<String, String> {
+ let _refresh_token = params
+ .get("refresh_token")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing refresh_token"))?;
+
+ // Client authentication using injected service
+ let (client_id, _client_secret) = self
+ .container
+ .client_authenticator
+ .authenticate(params, auth_header)
+ .map_err(|e| self.error_response("invalid_client", &e))?;
+
+ // Validate refresh token (implementation would verify token and get user info)
+ // For now, return a simple refresh token response
+ let new_token_id = Uuid::new_v4().to_string();
+ let access_token = self.container.token_generator.generate_access_token(
+ "test_user",
+ &client_id,
+ &None,
+ &new_token_id,
+ )?;
+ let new_refresh_token = self.container.token_generator.generate_refresh_token(
+ &client_id,
+ "test_user",
+ &None,
+ )?;
+
+ let token_response = TokenResponse {
+ access_token,
+ token_type: "Bearer".to_string(),
+ expires_in: 3600,
+ refresh_token: Some(new_refresh_token),
+ scope: None,
+ };
+
+ let _ = self.container.audit_logger.log_event(
+ "refresh_success",
+ Some(&client_id),
+ Some("test_user"),
+ ip_address.as_deref(),
+ true,
+ None,
+ );
+
+ serde_json::to_string(&token_response)
+ .map_err(|_| self.error_response("server_error", "Failed to serialize token response"))
+ }
+
+ pub fn handle_token_introspection(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ) -> Result<String, String> {
+ let token = params
+ .get("token")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?;
+
+ // Authenticate the client making the introspection request using injected service
+ let (_client_id, _client_secret) = self
+ .container
+ .client_authenticator
+ .authenticate(params, auth_header)
+ .map_err(|e| self.error_response("invalid_client", &e))?;
+
+ // Look up token in database using repository
+ let token_hash = format!("{:x}", Sha256::digest(token.as_bytes()));
+ let db_token = self
+ .container
+ .token_repository
+ .get_access_token(&token_hash)
+ .ok()
+ .flatten();
+
+ let response = if let Some(db_token) = db_token {
+ if !db_token.is_revoked && Utc::now() < db_token.expires_at {
+ TokenIntrospectionResponse {
+ active: true,
+ client_id: Some(db_token.client_id.clone()),
+ username: Some(db_token.user_id.clone()),
+ scope: db_token.scope.clone(),
+ exp: Some(db_token.expires_at.timestamp() as u64),
+ iat: Some(db_token.created_at.timestamp() as u64),
+ sub: Some(db_token.user_id),
+ aud: Some(db_token.client_id),
+ iss: Some(self.container.config.issuer_url.clone()),
+ jti: Some(db_token.token_id),
+ }
+ } else {
+ TokenIntrospectionResponse::inactive()
+ }
+ } else {
+ TokenIntrospectionResponse::inactive()
+ };
+
+ serde_json::to_string(&response)
+ .map_err(|_| self.error_response("server_error", "Failed to serialize response"))
+ }
+
+ pub fn handle_token_revocation(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ) -> Result<(), String> {
+ let token = params
+ .get("token")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?;
+
+ // Authenticate the client making the revocation request using injected service
+ let (client_id, _client_secret) = self
+ .container
+ .client_authenticator
+ .authenticate(params, auth_header)
+ .map_err(|e| self.error_response("invalid_client", &e))?;
+
+ // Revoke token in database using repository
+ let token_hash = format!("{:x}", Sha256::digest(token.as_bytes()));
+ let _ = self
+ .container
+ .token_repository
+ .revoke_access_token(&token_hash); // Ignore errors as per RFC 7009
+
+ let _ = self.container.audit_logger.log_event(
+ "token_revoked",
+ Some(&client_id),
+ None,
+ None,
+ true,
+ None,
+ );
+
+ Ok(())
+ }
+
+ fn error_response(&self, error: &str, description: &str) -> String {
+ let error_resp = ErrorResponse {
+ error: error.to_string(),
+ error_description: Some(description.to_string()),
+ error_uri: None,
+ };
+ serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string())
+ }
+
+ /// Cleanup expired data using repositories
+ pub fn cleanup_expired_data(&self) -> Result<()> {
+ self.container.cleanup_expired_data()
+ }
+}
diff --git a/src/oauth/types.rs b/src/oauth/types.rs
index 4f2c363..3d1c581 100644
--- a/src/oauth/types.rs
+++ b/src/oauth/types.rs
@@ -76,6 +76,23 @@ pub struct TokenIntrospectionResponse {
pub jti: Option<String>,
}
+impl TokenIntrospectionResponse {
+ pub fn inactive() -> Self {
+ Self {
+ active: false,
+ client_id: None,
+ username: None,
+ scope: None,
+ exp: None,
+ iat: None,
+ sub: None,
+ aud: None,
+ iss: None,
+ jti: None,
+ }
+ }
+}
+
#[derive(Debug, Serialize, Deserialize)]
pub struct TokenRevocationRequest {
pub token: String,
diff --git a/src/repositories/mod.rs b/src/repositories/mod.rs
new file mode 100644
index 0000000..1685fe0
--- /dev/null
+++ b/src/repositories/mod.rs
@@ -0,0 +1,52 @@
+use crate::database::{DbAccessToken, DbAuditLog, DbAuthCode, DbOAuthClient};
+use anyhow::Result;
+
+pub mod sqlite;
+
+pub use sqlite::{
+ SqliteAuditRepository, SqliteAuthCodeRepository, SqliteClientRepository, SqliteRateRepository,
+ SqliteTokenRepository,
+};
+
+/// Repository trait for OAuth client operations
+pub trait ClientRepository: Send + Sync {
+ fn get_client(&self, client_id: &str) -> Result<Option<DbOAuthClient>>;
+ fn create_client(&self, client: &DbOAuthClient) -> Result<()>;
+ fn update_client(&self, client: &DbOAuthClient) -> Result<()>;
+ fn delete_client(&self, client_id: &str) -> Result<()>;
+ fn list_clients(&self) -> Result<Vec<DbOAuthClient>>;
+}
+
+/// Repository trait for authorization code operations
+pub trait AuthCodeRepository: Send + Sync {
+ fn create_auth_code(&self, code: &DbAuthCode) -> Result<()>;
+ fn get_auth_code(&self, code: &str) -> Result<Option<DbAuthCode>>;
+ fn mark_auth_code_used(&self, code: &str) -> Result<()>;
+ fn cleanup_expired_codes(&self) -> Result<()>;
+}
+
+/// Repository trait for access token operations
+pub trait TokenRepository: Send + Sync {
+ fn create_access_token(&self, token: &DbAccessToken) -> Result<()>;
+ fn get_access_token(&self, token_hash: &str) -> Result<Option<DbAccessToken>>;
+ fn revoke_access_token(&self, token_hash: &str) -> Result<()>;
+ fn cleanup_expired_tokens(&self) -> Result<()>;
+}
+
+/// Repository trait for audit log operations
+pub trait AuditRepository: Send + Sync {
+ fn create_audit_log(&self, log: &DbAuditLog) -> Result<()>;
+ fn get_audit_logs(&self, limit: Option<i32>) -> Result<Vec<DbAuditLog>>;
+ fn cleanup_old_audit_logs(&self, days: i32) -> Result<()>;
+}
+
+/// Repository trait for rate limiting operations
+pub trait RateRepository: Send + Sync {
+ fn increment_rate_limit(
+ &self,
+ identifier: &str,
+ endpoint: &str,
+ window_size: i32,
+ ) -> Result<i32>;
+ fn cleanup_old_rate_limits(&self) -> Result<()>;
+}
diff --git a/src/repositories/sqlite.rs b/src/repositories/sqlite.rs
new file mode 100644
index 0000000..79e6025
--- /dev/null
+++ b/src/repositories/sqlite.rs
@@ -0,0 +1,164 @@
+use super::*;
+use crate::database::{Database, DbAccessToken, DbAuditLog, DbAuthCode, DbOAuthClient};
+use anyhow::Result;
+use std::sync::{Arc, Mutex};
+
+/// SQLite implementation of ClientRepository
+pub struct SqliteClientRepository {
+ database: Arc<Mutex<Database>>,
+}
+
+impl SqliteClientRepository {
+ pub fn new(database: Arc<Mutex<Database>>) -> Self {
+ Self { database }
+ }
+}
+
+impl ClientRepository for SqliteClientRepository {
+ fn get_client(&self, client_id: &str) -> Result<Option<DbOAuthClient>> {
+ let db = self.database.lock().unwrap();
+ db.get_oauth_client(client_id)
+ }
+
+ fn create_client(&self, client: &DbOAuthClient) -> Result<()> {
+ let db = self.database.lock().unwrap();
+ db.create_oauth_client(client).map(|_| ())
+ }
+
+ fn update_client(&self, client: &DbOAuthClient) -> Result<()> {
+ let db = self.database.lock().unwrap();
+ db.update_oauth_client(client)
+ }
+
+ fn delete_client(&self, client_id: &str) -> Result<()> {
+ let db = self.database.lock().unwrap();
+ db.delete_oauth_client(client_id)
+ }
+
+ fn list_clients(&self) -> Result<Vec<DbOAuthClient>> {
+ let db = self.database.lock().unwrap();
+ db.list_oauth_clients()
+ }
+}
+
+/// SQLite implementation of AuthCodeRepository
+pub struct SqliteAuthCodeRepository {
+ database: Arc<Mutex<Database>>,
+}
+
+impl SqliteAuthCodeRepository {
+ pub fn new(database: Arc<Mutex<Database>>) -> Self {
+ Self { database }
+ }
+}
+
+impl AuthCodeRepository for SqliteAuthCodeRepository {
+ fn create_auth_code(&self, code: &DbAuthCode) -> Result<()> {
+ let db = self.database.lock().unwrap();
+ db.create_auth_code(code).map(|_| ())
+ }
+
+ fn get_auth_code(&self, code: &str) -> Result<Option<DbAuthCode>> {
+ let db = self.database.lock().unwrap();
+ db.get_auth_code(code)
+ }
+
+ fn mark_auth_code_used(&self, code: &str) -> Result<()> {
+ let db = self.database.lock().unwrap();
+ db.mark_auth_code_used(code)
+ }
+
+ fn cleanup_expired_codes(&self) -> Result<()> {
+ let db = self.database.lock().unwrap();
+ db.cleanup_expired_codes().map(|_| ())
+ }
+}
+
+/// SQLite implementation of TokenRepository
+pub struct SqliteTokenRepository {
+ database: Arc<Mutex<Database>>,
+}
+
+impl SqliteTokenRepository {
+ pub fn new(database: Arc<Mutex<Database>>) -> Self {
+ Self { database }
+ }
+}
+
+impl TokenRepository for SqliteTokenRepository {
+ fn create_access_token(&self, token: &DbAccessToken) -> Result<()> {
+ let db = self.database.lock().unwrap();
+ db.create_access_token(token).map(|_| ())
+ }
+
+ fn get_access_token(&self, token_hash: &str) -> Result<Option<DbAccessToken>> {
+ let db = self.database.lock().unwrap();
+ db.get_access_token(token_hash)
+ }
+
+ fn revoke_access_token(&self, token_hash: &str) -> Result<()> {
+ let db = self.database.lock().unwrap();
+ db.revoke_access_token(token_hash)
+ }
+
+ fn cleanup_expired_tokens(&self) -> Result<()> {
+ let db = self.database.lock().unwrap();
+ db.cleanup_expired_tokens().map(|_| ())
+ }
+}
+
+/// SQLite implementation of AuditRepository
+pub struct SqliteAuditRepository {
+ database: Arc<Mutex<Database>>,
+}
+
+impl SqliteAuditRepository {
+ pub fn new(database: Arc<Mutex<Database>>) -> Self {
+ Self { database }
+ }
+}
+
+impl AuditRepository for SqliteAuditRepository {
+ fn create_audit_log(&self, log: &DbAuditLog) -> Result<()> {
+ let db = self.database.lock().unwrap();
+ db.create_audit_log(log).map(|_| ())
+ }
+
+ fn get_audit_logs(&self, limit: Option<i32>) -> Result<Vec<DbAuditLog>> {
+ let db = self.database.lock().unwrap();
+ db.get_audit_logs(limit.unwrap_or(100))
+ }
+
+ fn cleanup_old_audit_logs(&self, days: i32) -> Result<()> {
+ let db = self.database.lock().unwrap();
+ db.cleanup_old_audit_logs(days).map(|_| ())
+ }
+}
+
+/// SQLite implementation of RateRepository
+pub struct SqliteRateRepository {
+ database: Arc<Mutex<Database>>,
+}
+
+impl SqliteRateRepository {
+ pub fn new(database: Arc<Mutex<Database>>) -> Self {
+ Self { database }
+ }
+}
+
+impl RateRepository for SqliteRateRepository {
+ fn increment_rate_limit(
+ &self,
+ identifier: &str,
+ endpoint: &str,
+ window_size: i32,
+ ) -> Result<i32> {
+ let db = self.database.lock().unwrap();
+ db.increment_rate_limit(identifier, endpoint, window_size)
+ }
+
+ fn cleanup_old_rate_limits(&self) -> Result<()> {
+ let db = self.database.lock().unwrap();
+ db.cleanup_old_rate_limits()
+ }
+}
diff --git a/src/services/implementations.rs b/src/services/implementations.rs
new file mode 100644
index 0000000..ff03165
--- /dev/null
+++ b/src/services/implementations.rs
@@ -0,0 +1,217 @@
+use super::*;
+use crate::clients::parse_basic_auth;
+use crate::config::Config;
+use crate::database::DbAuditLog;
+use crate::keys::KeyManager;
+use crate::oauth::types::Claims;
+use crate::repositories::{AuditRepository, ClientRepository, RateRepository};
+use anyhow::{Result, anyhow};
+use chrono::Utc;
+use jsonwebtoken::{Algorithm, Header, encode};
+use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
+use std::time::{SystemTime, UNIX_EPOCH};
+use uuid::Uuid;
+
+/// Default implementation of ClientAuthenticator
+pub struct DefaultClientAuthenticator {
+ client_repository: Arc<dyn ClientRepository>,
+}
+
+impl DefaultClientAuthenticator {
+ pub fn new(client_repository: Arc<dyn ClientRepository>) -> Self {
+ Self { client_repository }
+ }
+
+ fn authenticate_client(&self, client_id: &str, client_secret: &str) -> bool {
+ match self.client_repository.get_client(client_id) {
+ Ok(Some(client)) => {
+ // Use constant-time comparison to prevent timing attacks
+ use subtle::ConstantTimeEq;
+ let expected_hash = client.client_secret_hash.as_bytes();
+ let provided_hash = self.hash_client_secret(client_secret);
+ expected_hash.ct_eq(provided_hash.as_bytes()).into()
+ }
+ _ => false,
+ }
+ }
+
+ fn hash_client_secret(&self, secret: &str) -> String {
+ use sha2::{Digest, Sha256};
+ let mut hasher = Sha256::new();
+ hasher.update(secret.as_bytes());
+ format!("{:x}", hasher.finalize())
+ }
+}
+
+impl ClientAuthenticator for DefaultClientAuthenticator {
+ fn authenticate(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ) -> Result<(String, String), String> {
+ if let Some(auth_header) = auth_header {
+ // HTTP Basic Authentication (preferred method)
+ parse_basic_auth(auth_header).ok_or_else(|| "Invalid Authorization header".to_string())
+ } else {
+ // Form-based authentication (fallback)
+ let client_id = params
+ .get("client_id")
+ .ok_or_else(|| "Missing client_id".to_string())?;
+ let client_secret = params
+ .get("client_secret")
+ .ok_or_else(|| "Missing client_secret".to_string())?;
+
+ if !self.authenticate_client(client_id, client_secret) {
+ return Err("Invalid client credentials".to_string());
+ }
+
+ Ok((client_id.clone(), client_secret.clone()))
+ }
+ }
+}
+
+/// Default implementation of RateLimiter
+pub struct DefaultRateLimiter {
+ rate_repository: Arc<dyn RateRepository>,
+ config: Config,
+}
+
+impl DefaultRateLimiter {
+ pub fn new(rate_repository: Arc<dyn RateRepository>, config: Config) -> Self {
+ Self {
+ rate_repository,
+ config,
+ }
+ }
+}
+
+impl RateLimiter for DefaultRateLimiter {
+ fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> {
+ let count = self
+ .rate_repository
+ .increment_rate_limit(identifier, endpoint, 1)?;
+
+ if count > self.config.rate_limit_requests_per_minute as i32 {
+ return Err(anyhow!("Rate limit exceeded"));
+ }
+
+ Ok(())
+ }
+}
+
+/// Default implementation of AuditLogger
+pub struct DefaultAuditLogger {
+ audit_repository: Arc<dyn AuditRepository>,
+ config: Config,
+}
+
+impl DefaultAuditLogger {
+ pub fn new(audit_repository: Arc<dyn AuditRepository>, config: Config) -> Self {
+ Self {
+ audit_repository,
+ config,
+ }
+ }
+}
+
+impl AuditLogger for DefaultAuditLogger {
+ fn log_event(
+ &self,
+ event_type: &str,
+ client_id: Option<&str>,
+ user_id: Option<&str>,
+ ip_address: Option<&str>,
+ success: bool,
+ details: Option<&str>,
+ ) -> Result<()> {
+ if !self.config.enable_audit_logging {
+ return Ok(());
+ }
+
+ let log = DbAuditLog {
+ id: 0,
+ event_type: event_type.to_string(),
+ client_id: client_id.map(|s| s.to_string()),
+ user_id: user_id.map(|s| s.to_string()),
+ ip_address: ip_address.map(|s| s.to_string()),
+ user_agent: None, // Could be passed in from HTTP layer
+ details: details.map(|s| s.to_string()),
+ created_at: Utc::now(),
+ success,
+ };
+
+ self.audit_repository.create_audit_log(&log)?;
+ Ok(())
+ }
+}
+
+/// Default implementation of TokenGenerator
+pub struct DefaultTokenGenerator {
+ key_manager: Arc<Mutex<KeyManager>>,
+ config: Config,
+}
+
+impl DefaultTokenGenerator {
+ pub fn new(key_manager: Arc<Mutex<KeyManager>>, config: Config) -> Self {
+ Self {
+ key_manager,
+ config,
+ }
+ }
+}
+
+impl TokenGenerator for DefaultTokenGenerator {
+ fn generate_access_token(
+ &self,
+ user_id: &str,
+ client_id: &str,
+ scope: &Option<String>,
+ token_id: &str,
+ ) -> Result<String, String> {
+ let mut key_manager = self.key_manager.lock().unwrap();
+
+ // Check if we need to rotate keys
+ if key_manager.should_rotate() {
+ if let Err(_) = key_manager.rotate_keys() {
+ return Err("Key rotation failed".to_string());
+ }
+ }
+
+ let current_key = key_manager
+ .get_current_key()
+ .ok_or_else(|| "No signing key available".to_string())?;
+
+ let now = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ let claims = Claims {
+ sub: user_id.to_string(),
+ iss: self.config.issuer_url.clone(),
+ aud: client_id.to_string(),
+ exp: now + 3600,
+ iat: now,
+ scope: scope.clone(),
+ jti: Some(token_id.to_string()),
+ };
+
+ let mut header = Header::new(Algorithm::RS256);
+ header.kid = Some(current_key.kid.clone());
+
+ encode(&header, &claims, &current_key.encoding_key)
+ .map_err(|_| "Failed to generate token".to_string())
+ }
+
+ fn generate_refresh_token(
+ &self,
+ _client_id: &str,
+ _user_id: &str,
+ _scope: &Option<String>,
+ ) -> Result<String, String> {
+ // For now, return a simple UUID-based refresh token
+ // In production, this should be a proper JWT or encrypted token
+ Ok(Uuid::new_v4().to_string())
+ }
+}
diff --git a/src/services/mod.rs b/src/services/mod.rs
new file mode 100644
index 0000000..26d74e3
--- /dev/null
+++ b/src/services/mod.rs
@@ -0,0 +1,49 @@
+use anyhow::Result;
+use std::collections::HashMap;
+
+/// Service trait for client authentication
+pub trait ClientAuthenticator: Send + Sync {
+ fn authenticate(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ) -> Result<(String, String), String>; // Returns (client_id, client_secret)
+}
+
+/// Service trait for rate limiting
+pub trait RateLimiter: Send + Sync {
+ fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()>;
+}
+
+/// Service trait for audit logging
+pub trait AuditLogger: Send + Sync {
+ fn log_event(
+ &self,
+ event_type: &str,
+ client_id: Option<&str>,
+ user_id: Option<&str>,
+ ip_address: Option<&str>,
+ success: bool,
+ details: Option<&str>,
+ ) -> Result<()>;
+}
+
+/// Service trait for token generation
+pub trait TokenGenerator: Send + Sync {
+ fn generate_access_token(
+ &self,
+ user_id: &str,
+ client_id: &str,
+ scope: &Option<String>,
+ token_id: &str,
+ ) -> Result<String, String>;
+
+ fn generate_refresh_token(
+ &self,
+ client_id: &str,
+ user_id: &str,
+ scope: &Option<String>,
+ ) -> Result<String, String>;
+}
+
+pub mod implementations;