diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-12 09:59:14 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-12 09:59:14 -0600 |
| commit | 86b6f94b65cb3e821bab75439852afabfec78377 (patch) | |
| tree | f5eaac815ce0c4cc0ac1ee8ab25fe6eecb8dc95d | |
| parent | c28b7088b6fad045060a52b6e1a2249e876090e3 (diff) | |
| -rw-r--r-- | REFACTORING_PLAN.md | 289 | ||||
| -rw-r--r-- | src/database.rs | 121 | ||||
| -rw-r--r-- | src/oauth/server.rs | 205 |
3 files changed, 304 insertions, 311 deletions
diff --git a/REFACTORING_PLAN.md b/REFACTORING_PLAN.md deleted file mode 100644 index ec24449..0000000 --- a/REFACTORING_PLAN.md +++ /dev/null @@ -1,289 +0,0 @@ -# OAuth2 STS SOLID Refactoring Plan - -## Overview -This document outlines refactoring opportunities to better align with SOLID principles. - -## 1. Extract Grant Type Handlers (SRP + OCP) - -### Current Issue -`OAuthServer::handle_token()` contains all grant type logic in one method. - -### Proposed Solution -```rust -trait GrantHandler { - fn can_handle(&self, grant_type: &str) -> bool; - fn handle_grant( - &self, - params: &HashMap<String, String>, - auth_header: Option<&str>, - ip_address: Option<String>, - ) -> Result<String, String>; -} - -struct AuthorizationCodeGrantHandler { - client_authenticator: Arc<dyn ClientAuthenticator>, - token_generator: Arc<dyn TokenGenerator>, - code_repository: Arc<dyn AuthCodeRepository>, - token_repository: Arc<dyn TokenRepository>, - audit_logger: Arc<dyn AuditLogger>, -} - -struct RefreshTokenGrantHandler { - // Similar dependencies -} - -struct GrantHandlerRegistry { - handlers: Vec<Arc<dyn GrantHandler>>, -} -``` - -### Benefits -- ✅ Easy to add new grant types (Client Credentials, Device Code, etc.) -- ✅ Each handler has single responsibility -- ✅ Testable in isolation - -## 2. Create Repository Abstractions (DIP + SRP) - -### Current Issue -Monolithic `Database` struct violates SRP and makes testing difficult. - -### Proposed Solution -```rust -trait ClientRepository { - fn get_client(&self, client_id: &str) -> Result<Option<OAuthClient>>; - fn authenticate_client(&self, client_id: &str, secret: &str) -> Result<bool>; - fn is_redirect_uri_valid(&self, client_id: &str, uri: &str) -> bool; -} - -trait TokenRepository { - fn create_access_token(&self, token: &DbAccessToken) -> Result<()>; - fn get_access_token(&self, token_hash: &str) -> Result<Option<DbAccessToken>>; - fn revoke_access_token(&self, token_hash: &str) -> Result<()>; -} - -trait AuthCodeRepository { - fn create_auth_code(&self, code: &DbAuthCode) -> Result<()>; - fn get_auth_code(&self, code: &str) -> Result<Option<DbAuthCode>>; - fn mark_auth_code_used(&self, code: &str) -> Result<()>; -} - -// SQLite implementations -struct SqliteClientRepository { - database: Arc<Mutex<Database>>, -} - -struct SqliteTokenRepository { - database: Arc<Mutex<Database>>, -} -``` - -### Benefits -- ✅ Easy to swap database backends (PostgreSQL, Redis, etc.) -- ✅ Better testability with mock repositories -- ✅ Clear separation of concerns - -## 3. Extract Authentication Strategy (OCP + SRP) - -### Current Issue -Client authentication logic scattered throughout token handlers. - -### Proposed Solution -```rust -trait ClientAuthenticator { - fn authenticate( - &self, - params: &HashMap<String, String>, - auth_header: Option<&str>, - ) -> Result<(String, String), String>; // Returns (client_id, client_secret) -} - -struct BasicAuthenticator { - client_repository: Arc<dyn ClientRepository>, -} - -struct PostAuthenticator { - client_repository: Arc<dyn ClientRepository>, -} - -struct CompositeAuthenticator { - authenticators: Vec<Arc<dyn ClientAuthenticator>>, -} -``` - -### Benefits -- ✅ Easy to add new authentication methods (JWT, mTLS, etc.) -- ✅ Clear separation of authentication concerns -- ✅ Configurable authentication strategies - -## 4. Extract Token Generation (SRP + OCP) - -### Current Issue -Token generation logic embedded in `OAuthServer`. - -### Proposed Solution -```rust -trait TokenGenerator { - fn generate_access_token( - &self, - user_id: &str, - client_id: &str, - scope: &Option<String>, - token_id: &str, - ) -> Result<String>; - - fn generate_refresh_token( - &self, - client_id: &str, - user_id: &str, - scope: &Option<String>, - ) -> Result<String>; -} - -struct JwtTokenGenerator { - key_manager: Arc<Mutex<KeyManager>>, - config: Config, -} - -struct OpaqueTokenGenerator { - // For opaque token implementation -} -``` - -### Benefits -- ✅ Easy to switch token formats (JWT vs opaque) -- ✅ Configurable token generation strategies -- ✅ Better testability - -## 5. Extract Cross-Cutting Concerns - -### Rate Limiting -```rust -trait RateLimiter { - fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()>; -} - -struct DatabaseRateLimiter { - rate_repository: Arc<dyn RateRepository>, - config: Config, -} - -struct RedisRateLimiter { - redis_client: redis::Client, - config: Config, -} -``` - -### Audit Logging -```rust -trait AuditLogger { - fn log_event(&self, event: AuditEvent) -> Result<()>; -} - -struct DatabaseAuditLogger { - audit_repository: Arc<dyn AuditRepository>, -} - -struct JsonFileAuditLogger { - file_path: PathBuf, -} -``` - -## 6. Refactor HTTP Layer (SRP) - -### Current Issue -`Server` mixes HTTP protocol with OAuth business logic. - -### Proposed Solution -```rust -struct HttpServer { - router: Router, - config: Config, -} - -struct Router { - oauth_handler: Arc<OAuthHandler>, - static_handler: Arc<StaticHandler>, -} - -struct OAuthHandler { - oauth_service: Arc<OAuthService>, // Renamed from OAuthServer -} - -// Clean separation between HTTP concerns and OAuth business logic -``` - -## 7. Dependency Injection Container - -### Proposed Solution -```rust -struct ServiceContainer { - // Repositories - client_repository: Arc<dyn ClientRepository>, - token_repository: Arc<dyn TokenRepository>, - auth_code_repository: Arc<dyn AuthCodeRepository>, - - // Services - client_authenticator: Arc<dyn ClientAuthenticator>, - token_generator: Arc<dyn TokenGenerator>, - rate_limiter: Arc<dyn RateLimiter>, - audit_logger: Arc<dyn AuditLogger>, - - // Grant handlers - grant_registry: Arc<GrantHandlerRegistry>, -} - -impl ServiceContainer { - fn new(config: &Config) -> Result<Self> { - // Wire up all dependencies - } -} -``` - -## Implementation Strategy - -### Phase 1: Repository Extraction -1. Create repository traits -2. Move database methods to specific repositories -3. Update `OAuthServer` to use repositories - -### Phase 2: Grant Handler Extraction -1. Create `GrantHandler` trait -2. Extract authorization code handler -3. Extract refresh token handler -4. Create registry - -### Phase 3: Cross-Cutting Concerns -1. Extract rate limiting -2. Extract audit logging -3. Extract authentication - -### Phase 4: HTTP Layer Cleanup -1. Separate HTTP protocol from business logic -2. Create clean request/response handlers - -### Phase 5: Dependency Injection -1. Create service container -2. Wire up all dependencies -3. Update main.rs to use container - -## Benefits of This Refactoring - -### Maintainability -- ✅ Easier to understand and modify individual components -- ✅ Clear separation of concerns -- ✅ Reduced coupling between components - -### Extensibility -- ✅ Easy to add new grant types -- ✅ Easy to swap implementations (database, token format, etc.) -- ✅ Configurable strategies for cross-cutting concerns - -### Testability -- ✅ Each component can be tested in isolation -- ✅ Easy to create mock implementations -- ✅ Better unit test coverage - -### Production Readiness -- ✅ Easy to scale different components independently -- ✅ Better observability and monitoring capabilities -- ✅ More flexible deployment options
\ No newline at end of file diff --git a/src/database.rs b/src/database.rs index 178eee3..0611efd 100644 --- a/src/database.rs +++ b/src/database.rs @@ -342,6 +342,127 @@ impl Database { Ok(()) } + // Refresh Token operations + pub fn create_refresh_token(&self, token: &DbRefreshToken) -> Result<i64> { + let mut stmt = self.conn.prepare( + "INSERT INTO refresh_tokens + (token_id, access_token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", + )?; + + let id = stmt.insert(params![ + token.token_id, + token.access_token_id, + token.client_id, + token.user_id, + token.scope, + token.expires_at.to_rfc3339(), + token.created_at.to_rfc3339(), + token.is_revoked, + token.token_hash + ])?; + + Ok(id) + } + + pub fn get_refresh_token(&self, token_hash: &str) -> Result<Option<DbRefreshToken>> { + let mut stmt = self.conn.prepare( + "SELECT id, token_id, access_token_id, client_id, user_id, scope, expires_at, + created_at, is_revoked, token_hash + FROM refresh_tokens WHERE token_hash = ?1" + )?; + + let token = stmt.query_row([token_hash], |row| { + Ok(DbRefreshToken { + id: row.get(0)?, + token_id: row.get(1)?, + access_token_id: row.get(2)?, + client_id: row.get(3)?, + user_id: row.get(4)?, + scope: row.get(5)?, + expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 6, + "expires_at".to_string(), + rusqlite::types::Type::Text, + ) + })? + .with_timezone(&Utc), + created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?) + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 7, + "created_at".to_string(), + rusqlite::types::Type::Text, + ) + })? + .with_timezone(&Utc), + is_revoked: row.get(8)?, + token_hash: row.get(9)?, + }) + }); + + match token { + Ok(token) => Ok(Some(token)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(e.into()), + } + } + + pub fn revoke_refresh_token(&self, token_hash: &str) -> Result<()> { + self.conn.execute( + "UPDATE refresh_tokens SET is_revoked = 1 WHERE token_hash = ?1", + [token_hash], + )?; + Ok(()) + } + + pub fn get_refresh_token_by_access_token(&self, access_token_id: i64) -> Result<Option<DbRefreshToken>> { + let mut stmt = self.conn.prepare( + "SELECT id, token_id, access_token_id, client_id, user_id, scope, expires_at, + created_at, is_revoked, token_hash + FROM refresh_tokens WHERE access_token_id = ?1 AND is_revoked = 0" + )?; + + let token = stmt.query_row([access_token_id], |row| { + Ok(DbRefreshToken { + id: row.get(0)?, + token_id: row.get(1)?, + access_token_id: row.get(2)?, + client_id: row.get(3)?, + user_id: row.get(4)?, + scope: row.get(5)?, + expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 6, + "expires_at".to_string(), + rusqlite::types::Type::Text, + ) + })? + .with_timezone(&Utc), + created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?) + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 7, + "created_at".to_string(), + rusqlite::types::Type::Text, + ) + })? + .with_timezone(&Utc), + is_revoked: row.get(8)?, + token_hash: row.get(9)?, + }) + }); + + match token { + Ok(token) => Ok(Some(token)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(e.into()), + } + } + // RSA Key operations pub fn create_rsa_key(&self, key: &DbRsaKey) -> Result<i64> { let mut stmt = self.conn.prepare( diff --git a/src/oauth/server.rs b/src/oauth/server.rs index 37c3cbc..a951f7c 100644 --- a/src/oauth/server.rs +++ b/src/oauth/server.rs @@ -393,6 +393,21 @@ impl OAuthServer { } } + // Validate redirect_uri matches authorization request + if let Some(redirect_uri_param) = params.get("redirect_uri") { + if redirect_uri_param != &auth_code.redirect_uri { + self.audit_log( + "token_redirect_uri_mismatch", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "redirect_uri mismatch")); + } + } + // Generate tokens let token_id = Uuid::new_v4().to_string(); let access_token = self.generate_access_token( @@ -401,10 +416,8 @@ impl OAuthServer { &auth_code.scope, &token_id, )?; - let refresh_token = - self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?; - // Store token in database for revocation/introspection + // Store access token in database for revocation/introspection let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); let db_access_token = DbAccessToken { id: 0, @@ -415,15 +428,26 @@ impl OAuthServer { expires_at: Utc::now() + Duration::hours(1), created_at: Utc::now(), is_revoked: false, - token_hash, + token_hash: token_hash.clone(), }; - { + let access_token_db_id = { let db = self.database.lock().unwrap(); - if let Err(_) = db.create_access_token(&db_access_token) { - return Err(self.error_response("server_error", "Failed to store access token")); + match db.create_access_token(&db_access_token) { + Ok(id) => id, + Err(_) => { + return Err(self.error_response("server_error", "Failed to store access token")); + } } - } + }; + + // Generate and store refresh token + let refresh_token = self.generate_refresh_token( + access_token_db_id, + &client_id, + &auth_code.user_id, + &auth_code.scope + )?; let token_response = TokenResponse { access_token, @@ -452,7 +476,7 @@ impl OAuthServer { auth_header: Option<&str>, ip_address: Option<String>, ) -> Result<String, String> { - let _refresh_token = params + let refresh_token_str = params .get("refresh_token") .ok_or_else(|| self.error_response("invalid_request", "Missing refresh_token"))?; @@ -471,6 +495,19 @@ impl OAuthServer { (client_id.clone(), client_secret.clone()) }; + // Rate limiting check + if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/token") { + self.audit_log( + "refresh_rate_limited", + Some(&client_id), + None, + ip_address.as_deref(), + false, + Some(&e.to_string()), + ); + return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); + } + // Authenticate the client { let mut client_manager = self.client_manager.lock().unwrap(); @@ -487,25 +524,125 @@ impl OAuthServer { } } - // Validate refresh token (implementation would verify token and get user info) - // For now, return a simple refresh token response + // Validate refresh token + let refresh_token_hash = format!("{:x}", Sha256::digest(refresh_token_str.as_bytes())); + let refresh_token = { + let db = self.database.lock().unwrap(); + match db.get_refresh_token(&refresh_token_hash) { + Ok(Some(token)) => token, + Ok(None) => { + self.audit_log( + "refresh_invalid_token", + Some(&client_id), + None, + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Invalid refresh token")); + } + Err(_) => { + return Err(self.error_response("server_error", "Internal server error")); + } + } + }; + + // Validate refresh token hasn't expired and isn't revoked + if refresh_token.is_revoked { + self.audit_log( + "refresh_token_revoked", + Some(&client_id), + Some(&refresh_token.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Refresh token revoked")); + } + + if Utc::now() > refresh_token.expires_at { + self.audit_log( + "refresh_token_expired", + Some(&client_id), + Some(&refresh_token.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Refresh token expired")); + } + + if refresh_token.client_id != client_id { + self.audit_log( + "refresh_client_mismatch", + Some(&client_id), + Some(&refresh_token.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Client ID mismatch")); + } + + // Revoke the old refresh token (optional but recommended security practice) + { + let db = self.database.lock().unwrap(); + let _ = db.revoke_refresh_token(&refresh_token_hash); + } + + // Generate new tokens let new_token_id = Uuid::new_v4().to_string(); - let access_token = - self.generate_access_token("test_user", &client_id, &None, &new_token_id)?; - let new_refresh_token = self.generate_refresh_token(&client_id, "test_user", &None)?; + let access_token = self.generate_access_token( + &refresh_token.user_id, + &client_id, + &refresh_token.scope, + &new_token_id, + )?; + + // Store new access token + let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); + let db_access_token = DbAccessToken { + id: 0, + token_id: new_token_id.clone(), + client_id: client_id.clone(), + user_id: refresh_token.user_id.clone(), + scope: refresh_token.scope.clone(), + expires_at: Utc::now() + Duration::hours(1), + created_at: Utc::now(), + is_revoked: false, + token_hash: token_hash.clone(), + }; + + let access_token_db_id = { + let db = self.database.lock().unwrap(); + match db.create_access_token(&db_access_token) { + Ok(id) => id, + Err(_) => { + return Err(self.error_response("server_error", "Failed to store access token")); + } + } + }; + + // Generate new refresh token + let new_refresh_token = self.generate_refresh_token( + access_token_db_id, + &client_id, + &refresh_token.user_id, + &refresh_token.scope, + )?; let token_response = TokenResponse { access_token, token_type: "Bearer".to_string(), expires_in: 3600, refresh_token: Some(new_refresh_token), - scope: None, + scope: refresh_token.scope.clone(), }; self.audit_log( "refresh_success", Some(&client_id), - Some("test_user"), + Some(&refresh_token.user_id), ip_address.as_deref(), true, None, @@ -675,13 +812,37 @@ impl OAuthServer { fn generate_refresh_token( &self, - _client_id: &str, - _user_id: &str, - _scope: &Option<String>, + access_token_id: i64, + client_id: &str, + user_id: &str, + scope: &Option<String>, ) -> Result<String, String> { - // For now, return a simple UUID-based refresh token - // In production, this should be a proper JWT or encrypted token - Ok(Uuid::new_v4().to_string()) + use crate::database::DbRefreshToken; + + let refresh_token = Uuid::new_v4().to_string(); + let token_hash = format!("{:x}", Sha256::digest(refresh_token.as_bytes())); + + let db_refresh_token = DbRefreshToken { + id: 0, + token_id: Uuid::new_v4().to_string(), + access_token_id, + client_id: client_id.to_string(), + user_id: user_id.to_string(), + scope: scope.clone(), + expires_at: Utc::now() + Duration::days(30), // 30 day expiration for refresh tokens + created_at: Utc::now(), + is_revoked: false, + token_hash: token_hash.clone(), + }; + + { + let db = self.database.lock().unwrap(); + if let Err(_) = db.create_refresh_token(&db_refresh_token) { + return Err(self.error_response("server_error", "Failed to store refresh token")); + } + } + + Ok(refresh_token) } fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> { |
