summaryrefslogtreecommitdiff
path: root/src/oauth
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-11 17:11:39 -0600
committermo khan <mo@mokhan.ca>2025-06-11 17:11:39 -0600
commit5ffc9b007ccbd8a4510b58de72aaee53291d7973 (patch)
treef696a2a7599926d402c5456c434bd87e5e325c3a /src/oauth
parentdbd3c780f27bd5bee23adf6e280b84d669230e0d (diff)
refactor: apply SOLID principles
Diffstat (limited to 'src/oauth')
-rw-r--r--src/oauth/mod.rs4
-rw-r--r--src/oauth/pkce.rs18
-rw-r--r--src/oauth/server.rs8
-rw-r--r--src/oauth/service.rs566
-rw-r--r--src/oauth/types.rs17
5 files changed, 600 insertions, 13 deletions
diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs
index 4b18bb3..b2d46fa 100644
--- a/src/oauth/mod.rs
+++ b/src/oauth/mod.rs
@@ -1,9 +1,11 @@
pub mod pkce;
pub mod server;
+pub mod service;
pub mod types;
pub use pkce::{
- generate_code_challenge, generate_code_verifier, verify_code_challenge, CodeChallengeMethod,
+ CodeChallengeMethod, generate_code_challenge, generate_code_verifier, verify_code_challenge,
};
pub use server::OAuthServer;
+pub use service::OAuthService;
pub use types::{AuthCode, Claims, ErrorResponse, TokenResponse};
diff --git a/src/oauth/pkce.rs b/src/oauth/pkce.rs
index 406d364..0dfc1f8 100644
--- a/src/oauth/pkce.rs
+++ b/src/oauth/pkce.rs
@@ -1,5 +1,5 @@
-use anyhow::{anyhow, Result};
-use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
+use anyhow::{Result, anyhow};
+use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use sha2::{Digest, Sha256};
#[derive(Debug, Clone, PartialEq)]
@@ -124,12 +124,14 @@ mod tests {
assert!(verify_code_challenge("short", "challenge", &CodeChallengeMethod::Plain).is_err());
// Invalid characters
- assert!(verify_code_challenge(
- "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!",
- "challenge",
- &CodeChallengeMethod::Plain
- )
- .is_err());
+ assert!(
+ verify_code_challenge(
+ "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!",
+ "challenge",
+ &CodeChallengeMethod::Plain
+ )
+ .is_err()
+ );
}
#[test]
diff --git a/src/oauth/server.rs b/src/oauth/server.rs
index 7fd8b9c..37c3cbc 100644
--- a/src/oauth/server.rs
+++ b/src/oauth/server.rs
@@ -1,12 +1,12 @@
-use crate::clients::{parse_basic_auth, ClientManager};
+use crate::clients::{ClientManager, parse_basic_auth};
use crate::config::Config;
use crate::database::{Database, DbAccessToken, DbAuditLog, DbAuthCode};
use crate::keys::KeyManager;
-use crate::oauth::pkce::{verify_code_challenge, CodeChallengeMethod};
+use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge};
use crate::oauth::types::{Claims, ErrorResponse, TokenIntrospectionResponse, TokenResponse};
-use anyhow::{anyhow, Result};
+use anyhow::{Result, anyhow};
use chrono::{Duration, Utc};
-use jsonwebtoken::{encode, Algorithm, Header};
+use jsonwebtoken::{Algorithm, Header, encode};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
diff --git a/src/oauth/service.rs b/src/oauth/service.rs
new file mode 100644
index 0000000..1b4eb49
--- /dev/null
+++ b/src/oauth/service.rs
@@ -0,0 +1,566 @@
+use crate::container::ServiceContainer;
+use crate::database::{DbAccessToken, DbAuthCode};
+use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge};
+use crate::oauth::types::{ErrorResponse, TokenIntrospectionResponse, TokenResponse};
+use anyhow::Result;
+use chrono::{Duration, Utc};
+use sha2::{Digest, Sha256};
+use std::collections::HashMap;
+use std::sync::Arc;
+use url::Url;
+use uuid::Uuid;
+
+/// Refactored OAuth service using dependency injection
+pub struct OAuthService {
+ container: Arc<ServiceContainer>,
+}
+
+impl OAuthService {
+ pub fn new(container: Arc<ServiceContainer>) -> Self {
+ Self { container }
+ }
+
+ pub fn get_jwks(&self) -> String {
+ self.container.get_jwks()
+ }
+
+ pub fn handle_authorize(
+ &self,
+ params: &HashMap<String, String>,
+ ip_address: Option<String>,
+ ) -> Result<String, String> {
+ let client_id = params
+ .get("client_id")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
+
+ let redirect_uri = params
+ .get("redirect_uri")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?;
+
+ let response_type = params
+ .get("response_type")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?;
+
+ // Rate limiting check
+ if let Err(e) = self
+ .container
+ .rate_limiter
+ .check_rate_limit(&format!("client:{}", client_id), "/authorize")
+ {
+ let _ = self.container.audit_logger.log_event(
+ "authorize_rate_limited",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(&e.to_string()),
+ );
+ return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded"));
+ }
+
+ // Validate client exists
+ let client = match self.container.client_repository.get_client(client_id) {
+ Ok(Some(client)) => client,
+ Ok(None) => {
+ let _ = self.container.audit_logger.log_event(
+ "authorize_invalid_client",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ None,
+ );
+ return Err(self.error_response("invalid_client", "Invalid client_id"));
+ }
+ Err(_) => {
+ return Err(self.error_response("server_error", "Internal server error"));
+ }
+ };
+
+ // Validate redirect URI
+ let redirect_uris: Vec<String> =
+ serde_json::from_str(&client.redirect_uris).unwrap_or_else(|_| vec![]);
+
+ if !redirect_uris.contains(redirect_uri) {
+ let _ = self.container.audit_logger.log_event(
+ "authorize_invalid_redirect_uri",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(redirect_uri),
+ );
+ return Err(self.error_response("invalid_request", "Invalid redirect_uri"));
+ }
+
+ // Validate requested scopes
+ let scope = params.get("scope").cloned();
+ if let Some(ref scope_str) = scope {
+ let client_scopes: Vec<&str> = client.scopes.split_whitespace().collect();
+ let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect();
+
+ for requested_scope in &requested_scopes {
+ if !client_scopes.contains(requested_scope) {
+ let _ = self.container.audit_logger.log_event(
+ "authorize_invalid_scope",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ scope.as_deref(),
+ );
+ return Err(self.error_response("invalid_scope", "Invalid scope"));
+ }
+ }
+ }
+
+ if response_type != "code" {
+ let _ = self.container.audit_logger.log_event(
+ "authorize_unsupported_response_type",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(response_type),
+ );
+ return Err(self.error_response(
+ "unsupported_response_type",
+ "Only code response type supported",
+ ));
+ }
+
+ // PKCE validation (RFC 7636)
+ let code_challenge = params.get("code_challenge");
+ let code_challenge_method = params
+ .get("code_challenge_method")
+ .map(|method| CodeChallengeMethod::from_str(method))
+ .transpose()
+ .map_err(|_| self.error_response("invalid_request", "Invalid code_challenge_method"))?;
+
+ // For public clients, PKCE is required
+ if client.client_id.starts_with("public_") && code_challenge.is_none() {
+ let _ = self.container.audit_logger.log_event(
+ "authorize_missing_pkce",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ None,
+ );
+ return Err(self.error_response("invalid_request", "PKCE required for public clients"));
+ }
+
+ let code = Uuid::new_v4().to_string();
+ let expires_at = Utc::now() + Duration::minutes(10); // 10 minute expiration
+
+ let db_auth_code = DbAuthCode {
+ id: 0, // Will be set by database
+ code: code.clone(),
+ client_id: client_id.clone(),
+ user_id: "test_user".to_string(), // In a real implementation, this would come from authentication
+ redirect_uri: redirect_uri.clone(),
+ scope: scope.clone(),
+ expires_at,
+ created_at: Utc::now(),
+ is_used: false,
+ code_challenge: code_challenge.cloned(),
+ code_challenge_method: code_challenge_method
+ .as_ref()
+ .map(|m| m.as_str().to_string()),
+ };
+
+ // Save to database
+ if let Err(_) = self
+ .container
+ .auth_code_repository
+ .create_auth_code(&db_auth_code)
+ {
+ return Err(self.error_response("server_error", "Failed to create authorization code"));
+ }
+
+ let mut redirect_url = Url::parse(redirect_uri)
+ .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?;
+
+ redirect_url.query_pairs_mut().append_pair("code", &code);
+
+ if let Some(state) = params.get("state") {
+ redirect_url.query_pairs_mut().append_pair("state", state);
+ }
+
+ let _ = self.container.audit_logger.log_event(
+ "authorize_success",
+ Some(client_id),
+ Some("test_user"),
+ ip_address.as_deref(),
+ true,
+ None,
+ );
+
+ Ok(redirect_url.to_string())
+ }
+
+ pub fn handle_token(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ip_address: Option<String>,
+ ) -> Result<String, String> {
+ let grant_type = params
+ .get("grant_type")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?;
+
+ match grant_type.as_str() {
+ "authorization_code" => {
+ self.handle_authorization_code_grant(params, auth_header, ip_address)
+ }
+ "refresh_token" => self.handle_refresh_token_grant(params, auth_header, ip_address),
+ _ => {
+ let _ = self.container.audit_logger.log_event(
+ "token_unsupported_grant_type",
+ None,
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(grant_type),
+ );
+ Err(self.error_response("unsupported_grant_type", "Unsupported grant type"))
+ }
+ }
+ }
+
+ fn handle_authorization_code_grant(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ip_address: Option<String>,
+ ) -> Result<String, String> {
+ let code = params
+ .get("code")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?;
+
+ // Client authentication using injected service
+ let (client_id, _client_secret) = self
+ .container
+ .client_authenticator
+ .authenticate(params, auth_header)
+ .map_err(|e| self.error_response("invalid_client", &e))?;
+
+ // Rate limiting check
+ if let Err(e) = self
+ .container
+ .rate_limiter
+ .check_rate_limit(&format!("client:{}", client_id), "/token")
+ {
+ let _ = self.container.audit_logger.log_event(
+ "token_rate_limited",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(&e.to_string()),
+ );
+ return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded"));
+ }
+
+ // Get and validate authorization code
+ let auth_code = match self.container.auth_code_repository.get_auth_code(code) {
+ Ok(Some(auth_code)) => auth_code,
+ Ok(None) => {
+ let _ = self.container.audit_logger.log_event(
+ "token_invalid_code",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(code),
+ );
+ return Err(
+ self.error_response("invalid_grant", "Invalid or expired authorization code")
+ );
+ }
+ Err(_) => {
+ return Err(self.error_response("server_error", "Internal server error"));
+ }
+ };
+
+ // Validate code hasn't been used and hasn't expired
+ if auth_code.is_used {
+ let _ = self.container.audit_logger.log_event(
+ "token_code_reuse",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ Some(code),
+ );
+ return Err(self.error_response("invalid_grant", "Authorization code already used"));
+ }
+
+ if Utc::now() > auth_code.expires_at {
+ let _ = self.container.audit_logger.log_event(
+ "token_code_expired",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ Some(code),
+ );
+ return Err(self.error_response("invalid_grant", "Authorization code expired"));
+ }
+
+ if auth_code.client_id != client_id {
+ let _ = self.container.audit_logger.log_event(
+ "token_client_mismatch",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ None,
+ );
+ return Err(self.error_response("invalid_grant", "Client ID mismatch"));
+ }
+
+ // PKCE validation if code challenge was provided
+ if let Some(code_challenge) = &auth_code.code_challenge {
+ let code_verifier = params.get("code_verifier").ok_or_else(|| {
+ self.error_response("invalid_request", "Missing code_verifier for PKCE")
+ })?;
+
+ let challenge_method = auth_code
+ .code_challenge_method
+ .as_ref()
+ .and_then(|method| CodeChallengeMethod::from_str(method).ok())
+ .unwrap_or(CodeChallengeMethod::Plain);
+
+ if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method)
+ {
+ let _ = self.container.audit_logger.log_event(
+ "token_pkce_verification_failed",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ None,
+ );
+ return Err(self.error_response("invalid_grant", "PKCE verification failed"));
+ }
+ }
+
+ // Mark code as used
+ if let Err(_) = self
+ .container
+ .auth_code_repository
+ .mark_auth_code_used(code)
+ {
+ return Err(self.error_response("server_error", "Failed to mark code as used"));
+ }
+
+ // Generate tokens using injected service
+ let token_id = Uuid::new_v4().to_string();
+ let access_token = self.container.token_generator.generate_access_token(
+ &auth_code.user_id,
+ &client_id,
+ &auth_code.scope,
+ &token_id,
+ )?;
+ let refresh_token = self.container.token_generator.generate_refresh_token(
+ &client_id,
+ &auth_code.user_id,
+ &auth_code.scope,
+ )?;
+
+ // Store token in database for revocation/introspection
+ let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes()));
+ let db_access_token = DbAccessToken {
+ id: 0,
+ token_id: token_id.clone(),
+ client_id: client_id.clone(),
+ user_id: auth_code.user_id.clone(),
+ scope: auth_code.scope.clone(),
+ expires_at: Utc::now() + Duration::hours(1),
+ created_at: Utc::now(),
+ is_revoked: false,
+ token_hash,
+ };
+
+ if let Err(_) = self
+ .container
+ .token_repository
+ .create_access_token(&db_access_token)
+ {
+ return Err(self.error_response("server_error", "Failed to store access token"));
+ }
+
+ let token_response = TokenResponse {
+ access_token,
+ token_type: "Bearer".to_string(),
+ expires_in: 3600,
+ refresh_token: Some(refresh_token),
+ scope: auth_code.scope,
+ };
+
+ let _ = self.container.audit_logger.log_event(
+ "token_success",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ true,
+ None,
+ );
+
+ serde_json::to_string(&token_response)
+ .map_err(|_| self.error_response("server_error", "Failed to serialize token response"))
+ }
+
+ fn handle_refresh_token_grant(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ip_address: Option<String>,
+ ) -> Result<String, String> {
+ let _refresh_token = params
+ .get("refresh_token")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing refresh_token"))?;
+
+ // Client authentication using injected service
+ let (client_id, _client_secret) = self
+ .container
+ .client_authenticator
+ .authenticate(params, auth_header)
+ .map_err(|e| self.error_response("invalid_client", &e))?;
+
+ // Validate refresh token (implementation would verify token and get user info)
+ // For now, return a simple refresh token response
+ let new_token_id = Uuid::new_v4().to_string();
+ let access_token = self.container.token_generator.generate_access_token(
+ "test_user",
+ &client_id,
+ &None,
+ &new_token_id,
+ )?;
+ let new_refresh_token = self.container.token_generator.generate_refresh_token(
+ &client_id,
+ "test_user",
+ &None,
+ )?;
+
+ let token_response = TokenResponse {
+ access_token,
+ token_type: "Bearer".to_string(),
+ expires_in: 3600,
+ refresh_token: Some(new_refresh_token),
+ scope: None,
+ };
+
+ let _ = self.container.audit_logger.log_event(
+ "refresh_success",
+ Some(&client_id),
+ Some("test_user"),
+ ip_address.as_deref(),
+ true,
+ None,
+ );
+
+ serde_json::to_string(&token_response)
+ .map_err(|_| self.error_response("server_error", "Failed to serialize token response"))
+ }
+
+ pub fn handle_token_introspection(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ) -> Result<String, String> {
+ let token = params
+ .get("token")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?;
+
+ // Authenticate the client making the introspection request using injected service
+ let (_client_id, _client_secret) = self
+ .container
+ .client_authenticator
+ .authenticate(params, auth_header)
+ .map_err(|e| self.error_response("invalid_client", &e))?;
+
+ // Look up token in database using repository
+ let token_hash = format!("{:x}", Sha256::digest(token.as_bytes()));
+ let db_token = self
+ .container
+ .token_repository
+ .get_access_token(&token_hash)
+ .ok()
+ .flatten();
+
+ let response = if let Some(db_token) = db_token {
+ if !db_token.is_revoked && Utc::now() < db_token.expires_at {
+ TokenIntrospectionResponse {
+ active: true,
+ client_id: Some(db_token.client_id.clone()),
+ username: Some(db_token.user_id.clone()),
+ scope: db_token.scope.clone(),
+ exp: Some(db_token.expires_at.timestamp() as u64),
+ iat: Some(db_token.created_at.timestamp() as u64),
+ sub: Some(db_token.user_id),
+ aud: Some(db_token.client_id),
+ iss: Some(self.container.config.issuer_url.clone()),
+ jti: Some(db_token.token_id),
+ }
+ } else {
+ TokenIntrospectionResponse::inactive()
+ }
+ } else {
+ TokenIntrospectionResponse::inactive()
+ };
+
+ serde_json::to_string(&response)
+ .map_err(|_| self.error_response("server_error", "Failed to serialize response"))
+ }
+
+ pub fn handle_token_revocation(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ) -> Result<(), String> {
+ let token = params
+ .get("token")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?;
+
+ // Authenticate the client making the revocation request using injected service
+ let (client_id, _client_secret) = self
+ .container
+ .client_authenticator
+ .authenticate(params, auth_header)
+ .map_err(|e| self.error_response("invalid_client", &e))?;
+
+ // Revoke token in database using repository
+ let token_hash = format!("{:x}", Sha256::digest(token.as_bytes()));
+ let _ = self
+ .container
+ .token_repository
+ .revoke_access_token(&token_hash); // Ignore errors as per RFC 7009
+
+ let _ = self.container.audit_logger.log_event(
+ "token_revoked",
+ Some(&client_id),
+ None,
+ None,
+ true,
+ None,
+ );
+
+ Ok(())
+ }
+
+ fn error_response(&self, error: &str, description: &str) -> String {
+ let error_resp = ErrorResponse {
+ error: error.to_string(),
+ error_description: Some(description.to_string()),
+ error_uri: None,
+ };
+ serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string())
+ }
+
+ /// Cleanup expired data using repositories
+ pub fn cleanup_expired_data(&self) -> Result<()> {
+ self.container.cleanup_expired_data()
+ }
+}
diff --git a/src/oauth/types.rs b/src/oauth/types.rs
index 4f2c363..3d1c581 100644
--- a/src/oauth/types.rs
+++ b/src/oauth/types.rs
@@ -76,6 +76,23 @@ pub struct TokenIntrospectionResponse {
pub jti: Option<String>,
}
+impl TokenIntrospectionResponse {
+ pub fn inactive() -> Self {
+ Self {
+ active: false,
+ client_id: None,
+ username: None,
+ scope: None,
+ exp: None,
+ iat: None,
+ sub: None,
+ aud: None,
+ iss: None,
+ jti: None,
+ }
+ }
+}
+
#[derive(Debug, Serialize, Deserialize)]
pub struct TokenRevocationRequest {
pub token: String,