summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-12 09:59:14 -0600
committermo khan <mo@mokhan.ca>2025-06-12 09:59:14 -0600
commit86b6f94b65cb3e821bab75439852afabfec78377 (patch)
treef5eaac815ce0c4cc0ac1ee8ab25fe6eecb8dc95d
parentc28b7088b6fad045060a52b6e1a2249e876090e3 (diff)
feat: start to add support for refresh tokensHEADmain
-rw-r--r--REFACTORING_PLAN.md289
-rw-r--r--src/database.rs121
-rw-r--r--src/oauth/server.rs205
3 files changed, 304 insertions, 311 deletions
diff --git a/REFACTORING_PLAN.md b/REFACTORING_PLAN.md
deleted file mode 100644
index ec24449..0000000
--- a/REFACTORING_PLAN.md
+++ /dev/null
@@ -1,289 +0,0 @@
-# OAuth2 STS SOLID Refactoring Plan
-
-## Overview
-This document outlines refactoring opportunities to better align with SOLID principles.
-
-## 1. Extract Grant Type Handlers (SRP + OCP)
-
-### Current Issue
-`OAuthServer::handle_token()` contains all grant type logic in one method.
-
-### Proposed Solution
-```rust
-trait GrantHandler {
- fn can_handle(&self, grant_type: &str) -> bool;
- fn handle_grant(
- &self,
- params: &HashMap<String, String>,
- auth_header: Option<&str>,
- ip_address: Option<String>,
- ) -> Result<String, String>;
-}
-
-struct AuthorizationCodeGrantHandler {
- client_authenticator: Arc<dyn ClientAuthenticator>,
- token_generator: Arc<dyn TokenGenerator>,
- code_repository: Arc<dyn AuthCodeRepository>,
- token_repository: Arc<dyn TokenRepository>,
- audit_logger: Arc<dyn AuditLogger>,
-}
-
-struct RefreshTokenGrantHandler {
- // Similar dependencies
-}
-
-struct GrantHandlerRegistry {
- handlers: Vec<Arc<dyn GrantHandler>>,
-}
-```
-
-### Benefits
-- ✅ Easy to add new grant types (Client Credentials, Device Code, etc.)
-- ✅ Each handler has single responsibility
-- ✅ Testable in isolation
-
-## 2. Create Repository Abstractions (DIP + SRP)
-
-### Current Issue
-Monolithic `Database` struct violates SRP and makes testing difficult.
-
-### Proposed Solution
-```rust
-trait ClientRepository {
- fn get_client(&self, client_id: &str) -> Result<Option<OAuthClient>>;
- fn authenticate_client(&self, client_id: &str, secret: &str) -> Result<bool>;
- fn is_redirect_uri_valid(&self, client_id: &str, uri: &str) -> bool;
-}
-
-trait TokenRepository {
- fn create_access_token(&self, token: &DbAccessToken) -> Result<()>;
- fn get_access_token(&self, token_hash: &str) -> Result<Option<DbAccessToken>>;
- fn revoke_access_token(&self, token_hash: &str) -> Result<()>;
-}
-
-trait AuthCodeRepository {
- fn create_auth_code(&self, code: &DbAuthCode) -> Result<()>;
- fn get_auth_code(&self, code: &str) -> Result<Option<DbAuthCode>>;
- fn mark_auth_code_used(&self, code: &str) -> Result<()>;
-}
-
-// SQLite implementations
-struct SqliteClientRepository {
- database: Arc<Mutex<Database>>,
-}
-
-struct SqliteTokenRepository {
- database: Arc<Mutex<Database>>,
-}
-```
-
-### Benefits
-- ✅ Easy to swap database backends (PostgreSQL, Redis, etc.)
-- ✅ Better testability with mock repositories
-- ✅ Clear separation of concerns
-
-## 3. Extract Authentication Strategy (OCP + SRP)
-
-### Current Issue
-Client authentication logic scattered throughout token handlers.
-
-### Proposed Solution
-```rust
-trait ClientAuthenticator {
- fn authenticate(
- &self,
- params: &HashMap<String, String>,
- auth_header: Option<&str>,
- ) -> Result<(String, String), String>; // Returns (client_id, client_secret)
-}
-
-struct BasicAuthenticator {
- client_repository: Arc<dyn ClientRepository>,
-}
-
-struct PostAuthenticator {
- client_repository: Arc<dyn ClientRepository>,
-}
-
-struct CompositeAuthenticator {
- authenticators: Vec<Arc<dyn ClientAuthenticator>>,
-}
-```
-
-### Benefits
-- ✅ Easy to add new authentication methods (JWT, mTLS, etc.)
-- ✅ Clear separation of authentication concerns
-- ✅ Configurable authentication strategies
-
-## 4. Extract Token Generation (SRP + OCP)
-
-### Current Issue
-Token generation logic embedded in `OAuthServer`.
-
-### Proposed Solution
-```rust
-trait TokenGenerator {
- fn generate_access_token(
- &self,
- user_id: &str,
- client_id: &str,
- scope: &Option<String>,
- token_id: &str,
- ) -> Result<String>;
-
- fn generate_refresh_token(
- &self,
- client_id: &str,
- user_id: &str,
- scope: &Option<String>,
- ) -> Result<String>;
-}
-
-struct JwtTokenGenerator {
- key_manager: Arc<Mutex<KeyManager>>,
- config: Config,
-}
-
-struct OpaqueTokenGenerator {
- // For opaque token implementation
-}
-```
-
-### Benefits
-- ✅ Easy to switch token formats (JWT vs opaque)
-- ✅ Configurable token generation strategies
-- ✅ Better testability
-
-## 5. Extract Cross-Cutting Concerns
-
-### Rate Limiting
-```rust
-trait RateLimiter {
- fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()>;
-}
-
-struct DatabaseRateLimiter {
- rate_repository: Arc<dyn RateRepository>,
- config: Config,
-}
-
-struct RedisRateLimiter {
- redis_client: redis::Client,
- config: Config,
-}
-```
-
-### Audit Logging
-```rust
-trait AuditLogger {
- fn log_event(&self, event: AuditEvent) -> Result<()>;
-}
-
-struct DatabaseAuditLogger {
- audit_repository: Arc<dyn AuditRepository>,
-}
-
-struct JsonFileAuditLogger {
- file_path: PathBuf,
-}
-```
-
-## 6. Refactor HTTP Layer (SRP)
-
-### Current Issue
-`Server` mixes HTTP protocol with OAuth business logic.
-
-### Proposed Solution
-```rust
-struct HttpServer {
- router: Router,
- config: Config,
-}
-
-struct Router {
- oauth_handler: Arc<OAuthHandler>,
- static_handler: Arc<StaticHandler>,
-}
-
-struct OAuthHandler {
- oauth_service: Arc<OAuthService>, // Renamed from OAuthServer
-}
-
-// Clean separation between HTTP concerns and OAuth business logic
-```
-
-## 7. Dependency Injection Container
-
-### Proposed Solution
-```rust
-struct ServiceContainer {
- // Repositories
- client_repository: Arc<dyn ClientRepository>,
- token_repository: Arc<dyn TokenRepository>,
- auth_code_repository: Arc<dyn AuthCodeRepository>,
-
- // Services
- client_authenticator: Arc<dyn ClientAuthenticator>,
- token_generator: Arc<dyn TokenGenerator>,
- rate_limiter: Arc<dyn RateLimiter>,
- audit_logger: Arc<dyn AuditLogger>,
-
- // Grant handlers
- grant_registry: Arc<GrantHandlerRegistry>,
-}
-
-impl ServiceContainer {
- fn new(config: &Config) -> Result<Self> {
- // Wire up all dependencies
- }
-}
-```
-
-## Implementation Strategy
-
-### Phase 1: Repository Extraction
-1. Create repository traits
-2. Move database methods to specific repositories
-3. Update `OAuthServer` to use repositories
-
-### Phase 2: Grant Handler Extraction
-1. Create `GrantHandler` trait
-2. Extract authorization code handler
-3. Extract refresh token handler
-4. Create registry
-
-### Phase 3: Cross-Cutting Concerns
-1. Extract rate limiting
-2. Extract audit logging
-3. Extract authentication
-
-### Phase 4: HTTP Layer Cleanup
-1. Separate HTTP protocol from business logic
-2. Create clean request/response handlers
-
-### Phase 5: Dependency Injection
-1. Create service container
-2. Wire up all dependencies
-3. Update main.rs to use container
-
-## Benefits of This Refactoring
-
-### Maintainability
-- ✅ Easier to understand and modify individual components
-- ✅ Clear separation of concerns
-- ✅ Reduced coupling between components
-
-### Extensibility
-- ✅ Easy to add new grant types
-- ✅ Easy to swap implementations (database, token format, etc.)
-- ✅ Configurable strategies for cross-cutting concerns
-
-### Testability
-- ✅ Each component can be tested in isolation
-- ✅ Easy to create mock implementations
-- ✅ Better unit test coverage
-
-### Production Readiness
-- ✅ Easy to scale different components independently
-- ✅ Better observability and monitoring capabilities
-- ✅ More flexible deployment options \ No newline at end of file
diff --git a/src/database.rs b/src/database.rs
index 178eee3..0611efd 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -342,6 +342,127 @@ impl Database {
Ok(())
}
+ // Refresh Token operations
+ pub fn create_refresh_token(&self, token: &DbRefreshToken) -> Result<i64> {
+ let mut stmt = self.conn.prepare(
+ "INSERT INTO refresh_tokens
+ (token_id, access_token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash)
+ VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
+ )?;
+
+ let id = stmt.insert(params![
+ token.token_id,
+ token.access_token_id,
+ token.client_id,
+ token.user_id,
+ token.scope,
+ token.expires_at.to_rfc3339(),
+ token.created_at.to_rfc3339(),
+ token.is_revoked,
+ token.token_hash
+ ])?;
+
+ Ok(id)
+ }
+
+ pub fn get_refresh_token(&self, token_hash: &str) -> Result<Option<DbRefreshToken>> {
+ let mut stmt = self.conn.prepare(
+ "SELECT id, token_id, access_token_id, client_id, user_id, scope, expires_at,
+ created_at, is_revoked, token_hash
+ FROM refresh_tokens WHERE token_hash = ?1"
+ )?;
+
+ let token = stmt.query_row([token_hash], |row| {
+ Ok(DbRefreshToken {
+ id: row.get(0)?,
+ token_id: row.get(1)?,
+ access_token_id: row.get(2)?,
+ client_id: row.get(3)?,
+ user_id: row.get(4)?,
+ scope: row.get(5)?,
+ expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?)
+ .map_err(|_| {
+ rusqlite::Error::InvalidColumnType(
+ 6,
+ "expires_at".to_string(),
+ rusqlite::types::Type::Text,
+ )
+ })?
+ .with_timezone(&Utc),
+ created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?)
+ .map_err(|_| {
+ rusqlite::Error::InvalidColumnType(
+ 7,
+ "created_at".to_string(),
+ rusqlite::types::Type::Text,
+ )
+ })?
+ .with_timezone(&Utc),
+ is_revoked: row.get(8)?,
+ token_hash: row.get(9)?,
+ })
+ });
+
+ match token {
+ Ok(token) => Ok(Some(token)),
+ Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
+ Err(e) => Err(e.into()),
+ }
+ }
+
+ pub fn revoke_refresh_token(&self, token_hash: &str) -> Result<()> {
+ self.conn.execute(
+ "UPDATE refresh_tokens SET is_revoked = 1 WHERE token_hash = ?1",
+ [token_hash],
+ )?;
+ Ok(())
+ }
+
+ pub fn get_refresh_token_by_access_token(&self, access_token_id: i64) -> Result<Option<DbRefreshToken>> {
+ let mut stmt = self.conn.prepare(
+ "SELECT id, token_id, access_token_id, client_id, user_id, scope, expires_at,
+ created_at, is_revoked, token_hash
+ FROM refresh_tokens WHERE access_token_id = ?1 AND is_revoked = 0"
+ )?;
+
+ let token = stmt.query_row([access_token_id], |row| {
+ Ok(DbRefreshToken {
+ id: row.get(0)?,
+ token_id: row.get(1)?,
+ access_token_id: row.get(2)?,
+ client_id: row.get(3)?,
+ user_id: row.get(4)?,
+ scope: row.get(5)?,
+ expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?)
+ .map_err(|_| {
+ rusqlite::Error::InvalidColumnType(
+ 6,
+ "expires_at".to_string(),
+ rusqlite::types::Type::Text,
+ )
+ })?
+ .with_timezone(&Utc),
+ created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?)
+ .map_err(|_| {
+ rusqlite::Error::InvalidColumnType(
+ 7,
+ "created_at".to_string(),
+ rusqlite::types::Type::Text,
+ )
+ })?
+ .with_timezone(&Utc),
+ is_revoked: row.get(8)?,
+ token_hash: row.get(9)?,
+ })
+ });
+
+ match token {
+ Ok(token) => Ok(Some(token)),
+ Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
+ Err(e) => Err(e.into()),
+ }
+ }
+
// RSA Key operations
pub fn create_rsa_key(&self, key: &DbRsaKey) -> Result<i64> {
let mut stmt = self.conn.prepare(
diff --git a/src/oauth/server.rs b/src/oauth/server.rs
index 37c3cbc..a951f7c 100644
--- a/src/oauth/server.rs
+++ b/src/oauth/server.rs
@@ -393,6 +393,21 @@ impl OAuthServer {
}
}
+ // Validate redirect_uri matches authorization request
+ if let Some(redirect_uri_param) = params.get("redirect_uri") {
+ if redirect_uri_param != &auth_code.redirect_uri {
+ self.audit_log(
+ "token_redirect_uri_mismatch",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ None,
+ );
+ return Err(self.error_response("invalid_grant", "redirect_uri mismatch"));
+ }
+ }
+
// Generate tokens
let token_id = Uuid::new_v4().to_string();
let access_token = self.generate_access_token(
@@ -401,10 +416,8 @@ impl OAuthServer {
&auth_code.scope,
&token_id,
)?;
- let refresh_token =
- self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?;
- // Store token in database for revocation/introspection
+ // Store access token in database for revocation/introspection
let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes()));
let db_access_token = DbAccessToken {
id: 0,
@@ -415,15 +428,26 @@ impl OAuthServer {
expires_at: Utc::now() + Duration::hours(1),
created_at: Utc::now(),
is_revoked: false,
- token_hash,
+ token_hash: token_hash.clone(),
};
- {
+ let access_token_db_id = {
let db = self.database.lock().unwrap();
- if let Err(_) = db.create_access_token(&db_access_token) {
- return Err(self.error_response("server_error", "Failed to store access token"));
+ match db.create_access_token(&db_access_token) {
+ Ok(id) => id,
+ Err(_) => {
+ return Err(self.error_response("server_error", "Failed to store access token"));
+ }
}
- }
+ };
+
+ // Generate and store refresh token
+ let refresh_token = self.generate_refresh_token(
+ access_token_db_id,
+ &client_id,
+ &auth_code.user_id,
+ &auth_code.scope
+ )?;
let token_response = TokenResponse {
access_token,
@@ -452,7 +476,7 @@ impl OAuthServer {
auth_header: Option<&str>,
ip_address: Option<String>,
) -> Result<String, String> {
- let _refresh_token = params
+ let refresh_token_str = params
.get("refresh_token")
.ok_or_else(|| self.error_response("invalid_request", "Missing refresh_token"))?;
@@ -471,6 +495,19 @@ impl OAuthServer {
(client_id.clone(), client_secret.clone())
};
+ // Rate limiting check
+ if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/token") {
+ self.audit_log(
+ "refresh_rate_limited",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(&e.to_string()),
+ );
+ return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded"));
+ }
+
// Authenticate the client
{
let mut client_manager = self.client_manager.lock().unwrap();
@@ -487,25 +524,125 @@ impl OAuthServer {
}
}
- // Validate refresh token (implementation would verify token and get user info)
- // For now, return a simple refresh token response
+ // Validate refresh token
+ let refresh_token_hash = format!("{:x}", Sha256::digest(refresh_token_str.as_bytes()));
+ let refresh_token = {
+ let db = self.database.lock().unwrap();
+ match db.get_refresh_token(&refresh_token_hash) {
+ Ok(Some(token)) => token,
+ Ok(None) => {
+ self.audit_log(
+ "refresh_invalid_token",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ None,
+ );
+ return Err(self.error_response("invalid_grant", "Invalid refresh token"));
+ }
+ Err(_) => {
+ return Err(self.error_response("server_error", "Internal server error"));
+ }
+ }
+ };
+
+ // Validate refresh token hasn't expired and isn't revoked
+ if refresh_token.is_revoked {
+ self.audit_log(
+ "refresh_token_revoked",
+ Some(&client_id),
+ Some(&refresh_token.user_id),
+ ip_address.as_deref(),
+ false,
+ None,
+ );
+ return Err(self.error_response("invalid_grant", "Refresh token revoked"));
+ }
+
+ if Utc::now() > refresh_token.expires_at {
+ self.audit_log(
+ "refresh_token_expired",
+ Some(&client_id),
+ Some(&refresh_token.user_id),
+ ip_address.as_deref(),
+ false,
+ None,
+ );
+ return Err(self.error_response("invalid_grant", "Refresh token expired"));
+ }
+
+ if refresh_token.client_id != client_id {
+ self.audit_log(
+ "refresh_client_mismatch",
+ Some(&client_id),
+ Some(&refresh_token.user_id),
+ ip_address.as_deref(),
+ false,
+ None,
+ );
+ return Err(self.error_response("invalid_grant", "Client ID mismatch"));
+ }
+
+ // Revoke the old refresh token (optional but recommended security practice)
+ {
+ let db = self.database.lock().unwrap();
+ let _ = db.revoke_refresh_token(&refresh_token_hash);
+ }
+
+ // Generate new tokens
let new_token_id = Uuid::new_v4().to_string();
- let access_token =
- self.generate_access_token("test_user", &client_id, &None, &new_token_id)?;
- let new_refresh_token = self.generate_refresh_token(&client_id, "test_user", &None)?;
+ let access_token = self.generate_access_token(
+ &refresh_token.user_id,
+ &client_id,
+ &refresh_token.scope,
+ &new_token_id,
+ )?;
+
+ // Store new access token
+ let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes()));
+ let db_access_token = DbAccessToken {
+ id: 0,
+ token_id: new_token_id.clone(),
+ client_id: client_id.clone(),
+ user_id: refresh_token.user_id.clone(),
+ scope: refresh_token.scope.clone(),
+ expires_at: Utc::now() + Duration::hours(1),
+ created_at: Utc::now(),
+ is_revoked: false,
+ token_hash: token_hash.clone(),
+ };
+
+ let access_token_db_id = {
+ let db = self.database.lock().unwrap();
+ match db.create_access_token(&db_access_token) {
+ Ok(id) => id,
+ Err(_) => {
+ return Err(self.error_response("server_error", "Failed to store access token"));
+ }
+ }
+ };
+
+ // Generate new refresh token
+ let new_refresh_token = self.generate_refresh_token(
+ access_token_db_id,
+ &client_id,
+ &refresh_token.user_id,
+ &refresh_token.scope,
+ )?;
let token_response = TokenResponse {
access_token,
token_type: "Bearer".to_string(),
expires_in: 3600,
refresh_token: Some(new_refresh_token),
- scope: None,
+ scope: refresh_token.scope.clone(),
};
self.audit_log(
"refresh_success",
Some(&client_id),
- Some("test_user"),
+ Some(&refresh_token.user_id),
ip_address.as_deref(),
true,
None,
@@ -675,13 +812,37 @@ impl OAuthServer {
fn generate_refresh_token(
&self,
- _client_id: &str,
- _user_id: &str,
- _scope: &Option<String>,
+ access_token_id: i64,
+ client_id: &str,
+ user_id: &str,
+ scope: &Option<String>,
) -> Result<String, String> {
- // For now, return a simple UUID-based refresh token
- // In production, this should be a proper JWT or encrypted token
- Ok(Uuid::new_v4().to_string())
+ use crate::database::DbRefreshToken;
+
+ let refresh_token = Uuid::new_v4().to_string();
+ let token_hash = format!("{:x}", Sha256::digest(refresh_token.as_bytes()));
+
+ let db_refresh_token = DbRefreshToken {
+ id: 0,
+ token_id: Uuid::new_v4().to_string(),
+ access_token_id,
+ client_id: client_id.to_string(),
+ user_id: user_id.to_string(),
+ scope: scope.clone(),
+ expires_at: Utc::now() + Duration::days(30), // 30 day expiration for refresh tokens
+ created_at: Utc::now(),
+ is_revoked: false,
+ token_hash: token_hash.clone(),
+ };
+
+ {
+ let db = self.database.lock().unwrap();
+ if let Err(_) = db.create_refresh_token(&db_refresh_token) {
+ return Err(self.error_response("server_error", "Failed to store refresh token"));
+ }
+ }
+
+ Ok(refresh_token)
}
fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> {