summaryrefslogtreecommitdiff
path: root/src/oauth
diff options
context:
space:
mode:
Diffstat (limited to 'src/oauth')
-rw-r--r--src/oauth/mod.rs2
-rw-r--r--src/oauth/pkce.rs156
-rw-r--r--src/oauth/server.rs469
-rw-r--r--src/oauth/types.rs45
4 files changed, 620 insertions, 52 deletions
diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs
index 3a0d861..7fd0d7b 100644
--- a/src/oauth/mod.rs
+++ b/src/oauth/mod.rs
@@ -1,5 +1,7 @@
+pub mod pkce;
pub mod server;
pub mod types;
+pub use pkce::{CodeChallengeMethod, verify_code_challenge, generate_code_verifier, generate_code_challenge};
pub use server::OAuthServer;
pub use types::{AuthCode, Claims, ErrorResponse, TokenResponse};
diff --git a/src/oauth/pkce.rs b/src/oauth/pkce.rs
new file mode 100644
index 0000000..c943844
--- /dev/null
+++ b/src/oauth/pkce.rs
@@ -0,0 +1,156 @@
+use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
+use sha2::{Digest, Sha256};
+use anyhow::{anyhow, Result};
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum CodeChallengeMethod {
+ Plain,
+ S256,
+}
+
+impl CodeChallengeMethod {
+ pub fn from_str(s: &str) -> Result<Self> {
+ match s {
+ "plain" => Ok(CodeChallengeMethod::Plain),
+ "S256" => Ok(CodeChallengeMethod::S256),
+ _ => Err(anyhow!("Unsupported code challenge method: {}", s)),
+ }
+ }
+
+ pub fn as_str(&self) -> &'static str {
+ match self {
+ CodeChallengeMethod::Plain => "plain",
+ CodeChallengeMethod::S256 => "S256",
+ }
+ }
+}
+
+pub fn verify_code_challenge(
+ code_verifier: &str,
+ code_challenge: &str,
+ method: &CodeChallengeMethod,
+) -> Result<bool> {
+ // Validate code verifier format (RFC 7636 Section 4.1)
+ if code_verifier.len() < 43 || code_verifier.len() > 128 {
+ return Err(anyhow!("Code verifier length must be between 43 and 128 characters"));
+ }
+
+ // Code verifier must only contain unreserved characters
+ if !code_verifier.chars().all(|c| {
+ c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~'
+ }) {
+ return Err(anyhow!("Code verifier contains invalid characters"));
+ }
+
+ let computed_challenge = match method {
+ CodeChallengeMethod::Plain => code_verifier.to_string(),
+ CodeChallengeMethod::S256 => {
+ let mut hasher = Sha256::new();
+ hasher.update(code_verifier.as_bytes());
+ URL_SAFE_NO_PAD.encode(hasher.finalize())
+ }
+ };
+
+ Ok(computed_challenge == code_challenge)
+}
+
+pub fn generate_code_verifier() -> String {
+ use rand::Rng;
+ let mut rng = rand::thread_rng();
+
+ // Generate 32 random bytes and encode them
+ let bytes: Vec<u8> = (0..32).map(|_| rng.r#gen()).collect();
+ URL_SAFE_NO_PAD.encode(&bytes)
+}
+
+pub fn generate_code_challenge(verifier: &str, method: &CodeChallengeMethod) -> String {
+ match method {
+ CodeChallengeMethod::Plain => verifier.to_string(),
+ CodeChallengeMethod::S256 => {
+ let mut hasher = Sha256::new();
+ hasher.update(verifier.as_bytes());
+ URL_SAFE_NO_PAD.encode(hasher.finalize())
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_code_challenge_method_from_str() {
+ assert_eq!(CodeChallengeMethod::from_str("plain").unwrap(), CodeChallengeMethod::Plain);
+ assert_eq!(CodeChallengeMethod::from_str("S256").unwrap(), CodeChallengeMethod::S256);
+ assert!(CodeChallengeMethod::from_str("invalid").is_err());
+ }
+
+ #[test]
+ fn test_code_challenge_method_as_str() {
+ assert_eq!(CodeChallengeMethod::Plain.as_str(), "plain");
+ assert_eq!(CodeChallengeMethod::S256.as_str(), "S256");
+ }
+
+ #[test]
+ fn test_verify_code_challenge_plain() {
+ let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
+ let challenge = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
+
+ assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::Plain).unwrap());
+ assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::Plain).unwrap());
+ }
+
+ #[test]
+ fn test_verify_code_challenge_s256() {
+ let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
+ let challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
+
+ assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::S256).unwrap());
+ assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::S256).unwrap());
+ }
+
+ #[test]
+ fn test_verify_code_challenge_invalid_verifier() {
+ // Too short
+ assert!(verify_code_challenge("short", "challenge", &CodeChallengeMethod::Plain).is_err());
+
+ // Invalid characters
+ assert!(verify_code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", "challenge", &CodeChallengeMethod::Plain).is_err());
+ }
+
+ #[test]
+ fn test_generate_code_verifier() {
+ let verifier = generate_code_verifier();
+ assert!(verifier.len() >= 43);
+ assert!(verifier.len() <= 128);
+
+ // Should only contain valid characters
+ assert!(verifier.chars().all(|c| {
+ c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~'
+ }));
+ }
+
+ #[test]
+ fn test_generate_code_challenge() {
+ let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
+
+ let plain_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::Plain);
+ assert_eq!(plain_challenge, verifier);
+
+ let s256_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::S256);
+ assert_eq!(s256_challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM");
+ }
+
+ #[test]
+ fn test_round_trip() {
+ let verifier = generate_code_verifier();
+
+ // Test with S256
+ let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::S256);
+ assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::S256).unwrap());
+
+ // Test with Plain
+ let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::Plain);
+ assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::Plain).unwrap());
+ }
+} \ No newline at end of file
diff --git a/src/oauth/server.rs b/src/oauth/server.rs
index 243fdba..7552f00 100644
--- a/src/oauth/server.rs
+++ b/src/oauth/server.rs
@@ -1,29 +1,36 @@
use crate::clients::{parse_basic_auth, ClientManager};
use crate::config::Config;
+use crate::database::{Database, DbAuthCode, DbAccessToken, DbAuditLog};
use crate::keys::KeyManager;
-use crate::oauth::types::{AuthCode, Claims, ErrorResponse, TokenResponse};
+use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge};
+use crate::oauth::types::{Claims, ErrorResponse, TokenResponse, TokenIntrospectionResponse};
+use anyhow::{anyhow, Result};
+use chrono::{Duration, Utc};
use jsonwebtoken::{encode, Algorithm, Header};
+use sha2::{Digest, Sha256};
use std::collections::HashMap;
+use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};
use url::Url;
use uuid::Uuid;
pub struct OAuthServer {
config: Config,
- key_manager: std::sync::Mutex<KeyManager>,
- auth_codes: std::sync::Mutex<HashMap<String, AuthCode>>,
- client_manager: std::sync::Mutex<ClientManager>,
+ key_manager: Arc<Mutex<KeyManager>>,
+ client_manager: Arc<Mutex<ClientManager>>,
+ database: Arc<Mutex<Database>>,
}
impl OAuthServer {
- pub fn new(config: &Config) -> Result<Self, Box<dyn std::error::Error>> {
- let key_manager = KeyManager::new()?;
- let client_manager = ClientManager::new();
+ pub fn new(config: &Config) -> Result<Self> {
+ let database = Arc::new(Mutex::new(Database::new(&config.database_path)?));
+ let key_manager = Arc::new(Mutex::new(KeyManager::new(database.clone())?));
+ let client_manager = Arc::new(Mutex::new(ClientManager::new(database.clone())?));
Ok(Self {
- key_manager: std::sync::Mutex::new(key_manager),
- auth_codes: std::sync::Mutex::new(HashMap::new()),
- client_manager: std::sync::Mutex::new(client_manager),
+ key_manager,
+ client_manager,
+ database,
config: config.clone(),
})
}
@@ -36,7 +43,7 @@ impl OAuthServer {
}
}
- pub fn handle_authorize(&self, params: &HashMap<String, String>) -> Result<String, String> {
+ pub fn handle_authorize(&self, params: &HashMap<String, String>, ip_address: Option<String>) -> Result<String, String> {
let client_id = params
.get("client_id")
.ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
@@ -49,48 +56,90 @@ impl OAuthServer {
.get("response_type")
.ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?;
+ // Rate limiting check
+ if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/authorize") {
+ self.audit_log("authorize_rate_limited", Some(client_id), None, ip_address.as_deref(), false, Some(&e.to_string()));
+ return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded"));
+ }
+
// Validate client exists
- let client_manager = self.client_manager.lock().unwrap();
- let _client = client_manager
- .get_client(client_id)
- .ok_or_else(|| self.error_response("invalid_client", "Invalid client_id"))?;
+ let client = {
+ let mut client_manager = self.client_manager.lock().unwrap();
+ match client_manager.get_client_from_db(client_id) {
+ Ok(Some(client)) => client,
+ Ok(None) => {
+ self.audit_log("authorize_invalid_client", Some(client_id), None, ip_address.as_deref(), false, None);
+ return Err(self.error_response("invalid_client", "Invalid client_id"));
+ }
+ Err(_) => {
+ return Err(self.error_response("server_error", "Internal server error"));
+ }
+ }
+ };
// Validate redirect URI is registered for this client
- if !client_manager.is_redirect_uri_valid(client_id, redirect_uri) {
- return Err(self.error_response("invalid_request", "Invalid redirect_uri"));
+ {
+ let mut client_manager = self.client_manager.lock().unwrap();
+ if !client_manager.is_redirect_uri_valid(client_id, redirect_uri) {
+ self.audit_log("authorize_invalid_redirect_uri", Some(client_id), None, ip_address.as_deref(), false, Some(redirect_uri));
+ return Err(self.error_response("invalid_request", "Invalid redirect_uri"));
+ }
}
// Validate requested scopes
let scope = params.get("scope").cloned();
- if !client_manager.is_scope_valid(client_id, &scope) {
- return Err(self.error_response("invalid_scope", "Invalid scope"));
+ {
+ let mut client_manager = self.client_manager.lock().unwrap();
+ if !client_manager.is_scope_valid(client_id, &scope) {
+ self.audit_log("authorize_invalid_scope", Some(client_id), None, ip_address.as_deref(), false, scope.as_deref());
+ return Err(self.error_response("invalid_scope", "Invalid scope"));
+ }
}
if response_type != "code" {
+ self.audit_log("authorize_unsupported_response_type", Some(client_id), None, ip_address.as_deref(), false, Some(response_type));
return Err(self.error_response(
"unsupported_response_type",
"Only code response type supported",
));
}
+ // PKCE validation (RFC 7636)
+ let code_challenge = params.get("code_challenge");
+ let code_challenge_method = params.get("code_challenge_method")
+ .map(|method| CodeChallengeMethod::from_str(method))
+ .transpose()
+ .map_err(|_| self.error_response("invalid_request", "Invalid code_challenge_method"))?;
+
+ // For public clients, PKCE is required
+ if client.client_id.starts_with("public_") && code_challenge.is_none() {
+ self.audit_log("authorize_missing_pkce", Some(client_id), None, ip_address.as_deref(), false, None);
+ return Err(self.error_response("invalid_request", "PKCE required for public clients"));
+ }
+
let code = Uuid::new_v4().to_string();
- let expires_at = SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .unwrap()
- .as_secs()
- + 600;
+ let expires_at = Utc::now() + Duration::minutes(10); // 10 minute expiration
- let auth_code = AuthCode {
+ let db_auth_code = DbAuthCode {
+ id: 0, // Will be set by database
+ code: code.clone(),
client_id: client_id.clone(),
+ user_id: "test_user".to_string(), // In a real implementation, this would come from authentication
redirect_uri: redirect_uri.clone(),
- scope: scope,
+ scope: scope.clone(),
expires_at,
- user_id: "test_user".to_string(),
+ created_at: Utc::now(),
+ is_used: false,
+ code_challenge: code_challenge.cloned(),
+ code_challenge_method: code_challenge_method.as_ref().map(|m| m.as_str().to_string()),
};
+ // Save to database
{
- let mut codes = self.auth_codes.lock().unwrap();
- codes.insert(code.clone(), auth_code);
+ let db = self.database.lock().unwrap();
+ if let Err(_) = db.create_auth_code(&db_auth_code) {
+ return Err(self.error_response("server_error", "Failed to create authorization code"));
+ }
}
let mut redirect_url = Url::parse(redirect_uri)
@@ -102,6 +151,8 @@ impl OAuthServer {
redirect_url.query_pairs_mut().append_pair("state", state);
}
+ self.audit_log("authorize_success", Some(client_id), Some("test_user"), ip_address.as_deref(), true, None);
+
Ok(redirect_url.to_string())
}
@@ -109,24 +160,36 @@ impl OAuthServer {
&self,
params: &HashMap<String, String>,
auth_header: Option<&str>,
+ ip_address: Option<String>,
) -> Result<String, String> {
let grant_type = params
.get("grant_type")
.ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?;
- if grant_type != "authorization_code" {
- return Err(self.error_response(
- "unsupported_grant_type",
- "Only authorization_code grant type supported",
- ));
+ match grant_type.as_str() {
+ "authorization_code" => self.handle_authorization_code_grant(params, auth_header, ip_address),
+ "refresh_token" => self.handle_refresh_token_grant(params, auth_header, ip_address),
+ _ => {
+ self.audit_log("token_unsupported_grant_type", None, None, ip_address.as_deref(), false, Some(grant_type));
+ Err(self.error_response(
+ "unsupported_grant_type",
+ "Unsupported grant type",
+ ))
+ }
}
+ }
+ fn handle_authorization_code_grant(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ip_address: Option<String>,
+ ) -> Result<String, String> {
let code = params
.get("code")
.ok_or_else(|| self.error_response("invalid_request", "Missing code"))?;
// Client authentication - RFC 6749 Section 3.2.1
- // Clients can authenticate via HTTP Basic Auth or form parameters
let (client_id, client_secret) = if let Some(auth_header) = auth_header {
// HTTP Basic Authentication (preferred method)
parse_basic_auth(auth_header).ok_or_else(|| {
@@ -143,52 +206,293 @@ impl OAuthServer {
(client_id.clone(), client_secret.clone())
};
+ // Rate limiting check
+ if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/token") {
+ self.audit_log("token_rate_limited", Some(&client_id), None, ip_address.as_deref(), false, Some(&e.to_string()));
+ return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded"));
+ }
+
// Authenticate the client
- let client_manager = self.client_manager.lock().unwrap();
- if !client_manager.authenticate_client(&client_id, &client_secret) {
- return Err(self.error_response("invalid_client", "Client authentication failed"));
+ {
+ let mut client_manager = self.client_manager.lock().unwrap();
+ if !client_manager.authenticate_client(&client_id, &client_secret) {
+ self.audit_log("token_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None);
+ return Err(self.error_response("invalid_client", "Client authentication failed"));
+ }
}
+ // Get and validate authorization code
let auth_code = {
- let mut codes = self.auth_codes.lock().unwrap();
- codes.remove(code).ok_or_else(|| {
- self.error_response("invalid_grant", "Invalid or expired authorization code")
- })?
+ let db = self.database.lock().unwrap();
+ match db.get_auth_code(code) {
+ Ok(Some(auth_code)) => auth_code,
+ Ok(None) => {
+ self.audit_log("token_invalid_code", Some(&client_id), None, ip_address.as_deref(), false, Some(code));
+ return Err(self.error_response("invalid_grant", "Invalid or expired authorization code"));
+ }
+ Err(_) => {
+ return Err(self.error_response("server_error", "Internal server error"));
+ }
+ }
};
+ // Validate code hasn't been used and hasn't expired
+ if auth_code.is_used {
+ self.audit_log("token_code_reuse", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code));
+ return Err(self.error_response("invalid_grant", "Authorization code already used"));
+ }
+
+ if Utc::now() > auth_code.expires_at {
+ self.audit_log("token_code_expired", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code));
+ return Err(self.error_response("invalid_grant", "Authorization code expired"));
+ }
+
if auth_code.client_id != client_id {
+ self.audit_log("token_client_mismatch", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None);
return Err(self.error_response("invalid_grant", "Client ID mismatch"));
}
- let now = SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .unwrap()
- .as_secs();
+ // PKCE validation if code challenge was provided
+ if let Some(code_challenge) = &auth_code.code_challenge {
+ let code_verifier = params.get("code_verifier").ok_or_else(|| {
+ self.error_response("invalid_request", "Missing code_verifier for PKCE")
+ })?;
- if now > auth_code.expires_at {
- return Err(self.error_response("invalid_grant", "Authorization code expired"));
+ let challenge_method = auth_code.code_challenge_method
+ .as_ref()
+ .and_then(|method| CodeChallengeMethod::from_str(method).ok())
+ .unwrap_or(CodeChallengeMethod::Plain);
+
+ if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method) {
+ self.audit_log("token_pkce_verification_failed", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None);
+ return Err(self.error_response("invalid_grant", "PKCE verification failed"));
+ }
+ }
+
+ // Mark code as used
+ {
+ let db = self.database.lock().unwrap();
+ if let Err(_) = db.mark_auth_code_used(code) {
+ return Err(self.error_response("server_error", "Failed to mark code as used"));
+ }
}
- let access_token =
- self.generate_access_token(&auth_code.user_id, &client_id, &auth_code.scope)?;
+ // Generate tokens
+ let token_id = Uuid::new_v4().to_string();
+ let access_token = self.generate_access_token(&auth_code.user_id, &client_id, &auth_code.scope, &token_id)?;
+ let refresh_token = self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?;
+
+ // Store token in database for revocation/introspection
+ let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes()));
+ let db_access_token = DbAccessToken {
+ id: 0,
+ token_id: token_id.clone(),
+ client_id: client_id.clone(),
+ user_id: auth_code.user_id.clone(),
+ scope: auth_code.scope.clone(),
+ expires_at: Utc::now() + Duration::hours(1),
+ created_at: Utc::now(),
+ is_revoked: false,
+ token_hash,
+ };
+
+ {
+ let db = self.database.lock().unwrap();
+ if let Err(_) = db.create_access_token(&db_access_token) {
+ return Err(self.error_response("server_error", "Failed to store access token"));
+ }
+ }
let token_response = TokenResponse {
access_token,
token_type: "Bearer".to_string(),
expires_in: 3600,
- refresh_token: None,
+ refresh_token: Some(refresh_token),
scope: auth_code.scope,
};
+ self.audit_log("token_success", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), true, None);
+
serde_json::to_string(&token_response)
.map_err(|_| self.error_response("server_error", "Failed to serialize token response"))
}
+ fn handle_refresh_token_grant(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ip_address: Option<String>,
+ ) -> Result<String, String> {
+ let _refresh_token = params
+ .get("refresh_token")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing refresh_token"))?;
+
+ // Client authentication
+ let (client_id, client_secret) = if let Some(auth_header) = auth_header {
+ parse_basic_auth(auth_header).ok_or_else(|| {
+ self.error_response("invalid_client", "Invalid Authorization header")
+ })?
+ } else {
+ let client_id = params
+ .get("client_id")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
+ let client_secret = params
+ .get("client_secret")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing client_secret"))?;
+ (client_id.clone(), client_secret.clone())
+ };
+
+ // Authenticate the client
+ {
+ let mut client_manager = self.client_manager.lock().unwrap();
+ if !client_manager.authenticate_client(&client_id, &client_secret) {
+ self.audit_log("refresh_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None);
+ return Err(self.error_response("invalid_client", "Client authentication failed"));
+ }
+ }
+
+ // Validate refresh token (implementation would verify token and get user info)
+ // For now, return a simple refresh token response
+ let new_token_id = Uuid::new_v4().to_string();
+ let access_token = self.generate_access_token("test_user", &client_id, &None, &new_token_id)?;
+ let new_refresh_token = self.generate_refresh_token(&client_id, "test_user", &None)?;
+
+ let token_response = TokenResponse {
+ access_token,
+ token_type: "Bearer".to_string(),
+ expires_in: 3600,
+ refresh_token: Some(new_refresh_token),
+ scope: None,
+ };
+
+ self.audit_log("refresh_success", Some(&client_id), Some("test_user"), ip_address.as_deref(), true, None);
+
+ serde_json::to_string(&token_response)
+ .map_err(|_| self.error_response("server_error", "Failed to serialize token response"))
+ }
+
+ pub fn handle_token_introspection(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ) -> Result<String, String> {
+ let token = params
+ .get("token")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?;
+
+ // Authenticate the client making the introspection request
+ let (client_id, client_secret) = if let Some(auth_header) = auth_header {
+ parse_basic_auth(auth_header).ok_or_else(|| {
+ self.error_response("invalid_client", "Invalid Authorization header")
+ })?
+ } else {
+ return Err(self.error_response("invalid_client", "Client authentication required"));
+ };
+
+ {
+ let mut client_manager = self.client_manager.lock().unwrap();
+ if !client_manager.authenticate_client(&client_id, &client_secret) {
+ return Err(self.error_response("invalid_client", "Client authentication failed"));
+ }
+ }
+
+ // Look up token in database
+ let token_hash = format!("{:x}", Sha256::digest(token.as_bytes()));
+ let db_token = {
+ let db = self.database.lock().unwrap();
+ db.get_access_token(&token_hash).ok().flatten()
+ };
+
+ let response = if let Some(db_token) = db_token {
+ if !db_token.is_revoked && Utc::now() < db_token.expires_at {
+ TokenIntrospectionResponse {
+ active: true,
+ client_id: Some(db_token.client_id.clone()),
+ username: Some(db_token.user_id.clone()),
+ scope: db_token.scope.clone(),
+ exp: Some(db_token.expires_at.timestamp() as u64),
+ iat: Some(db_token.created_at.timestamp() as u64),
+ sub: Some(db_token.user_id),
+ aud: Some(db_token.client_id),
+ iss: Some(self.config.issuer_url.clone()),
+ jti: Some(db_token.token_id),
+ }
+ } else {
+ TokenIntrospectionResponse {
+ active: false,
+ client_id: None,
+ username: None,
+ scope: None,
+ exp: None,
+ iat: None,
+ sub: None,
+ aud: None,
+ iss: None,
+ jti: None,
+ }
+ }
+ } else {
+ TokenIntrospectionResponse {
+ active: false,
+ client_id: None,
+ username: None,
+ scope: None,
+ exp: None,
+ iat: None,
+ sub: None,
+ aud: None,
+ iss: None,
+ jti: None,
+ }
+ };
+
+ serde_json::to_string(&response)
+ .map_err(|_| self.error_response("server_error", "Failed to serialize response"))
+ }
+
+ pub fn handle_token_revocation(
+ &self,
+ params: &HashMap<String, String>,
+ auth_header: Option<&str>,
+ ) -> Result<(), String> {
+ let token = params
+ .get("token")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing token"))?;
+
+ // Authenticate the client making the revocation request
+ let (client_id, client_secret) = if let Some(auth_header) = auth_header {
+ parse_basic_auth(auth_header).ok_or_else(|| {
+ self.error_response("invalid_client", "Invalid Authorization header")
+ })?
+ } else {
+ return Err(self.error_response("invalid_client", "Client authentication required"));
+ };
+
+ {
+ let mut client_manager = self.client_manager.lock().unwrap();
+ if !client_manager.authenticate_client(&client_id, &client_secret) {
+ return Err(self.error_response("invalid_client", "Client authentication failed"));
+ }
+ }
+
+ // Revoke token in database
+ let token_hash = format!("{:x}", Sha256::digest(token.as_bytes()));
+ {
+ let db = self.database.lock().unwrap();
+ let _ = db.revoke_access_token(&token_hash); // Ignore errors as per RFC 7009
+ }
+
+ self.audit_log("token_revoked", Some(&client_id), None, None, true, None);
+
+ Ok(())
+ }
+
fn generate_access_token(
&self,
user_id: &str,
client_id: &str,
scope: &Option<String>,
+ token_id: &str,
) -> Result<String, String> {
let mut key_manager = self.key_manager.lock().unwrap();
@@ -215,6 +519,7 @@ impl OAuthServer {
exp: now + 3600,
iat: now,
scope: scope.clone(),
+ jti: Some(token_id.to_string()),
};
let mut header = Header::new(Algorithm::RS256);
@@ -224,11 +529,71 @@ impl OAuthServer {
.map_err(|_| self.error_response("server_error", "Failed to generate token"))
}
+ fn generate_refresh_token(
+ &self,
+ _client_id: &str,
+ _user_id: &str,
+ _scope: &Option<String>,
+ ) -> Result<String, String> {
+ // For now, return a simple UUID-based refresh token
+ // In production, this should be a proper JWT or encrypted token
+ Ok(Uuid::new_v4().to_string())
+ }
+
+ fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> {
+ let db = self.database.lock().unwrap();
+ let count = db.increment_rate_limit(identifier, endpoint, 1)?;
+
+ if count > self.config.rate_limit_requests_per_minute as i32 {
+ return Err(anyhow!("Rate limit exceeded"));
+ }
+
+ Ok(())
+ }
+
+ fn audit_log(&self, event_type: &str, client_id: Option<&str>, user_id: Option<&str>, ip_address: Option<&str>, success: bool, details: Option<&str>) {
+ if !self.config.enable_audit_logging {
+ return;
+ }
+
+ let log = DbAuditLog {
+ id: 0,
+ event_type: event_type.to_string(),
+ client_id: client_id.map(|s| s.to_string()),
+ user_id: user_id.map(|s| s.to_string()),
+ ip_address: ip_address.map(|s| s.to_string()),
+ user_agent: None, // Could be passed in from HTTP layer
+ details: details.map(|s| s.to_string()),
+ created_at: Utc::now(),
+ success,
+ };
+
+ let db = self.database.lock().unwrap();
+ let _ = db.create_audit_log(&log); // Ignore errors in audit logging
+ }
+
fn error_response(&self, error: &str, description: &str) -> String {
let error_resp = ErrorResponse {
error: error.to_string(),
error_description: Some(description.to_string()),
+ error_uri: None,
};
serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string())
}
-}
+
+ // Cleanup expired data
+ pub fn cleanup_expired_data(&self) -> Result<()> {
+ let db = self.database.lock().unwrap();
+
+ // Cleanup expired authorization codes
+ let _ = db.cleanup_expired_codes();
+
+ // Cleanup expired tokens
+ let _ = db.cleanup_expired_tokens();
+
+ // Cleanup old audit logs (keep for 30 days)
+ let _ = db.cleanup_old_audit_logs(30);
+
+ Ok(())
+ }
+} \ No newline at end of file
diff --git a/src/oauth/types.rs b/src/oauth/types.rs
index 6c62edf..0f9be5c 100644
--- a/src/oauth/types.rs
+++ b/src/oauth/types.rs
@@ -1,4 +1,5 @@
use serde::{Deserialize, Serialize};
+use crate::oauth::pkce::CodeChallengeMethod;
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
@@ -9,6 +10,8 @@ pub struct Claims {
pub iat: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub jti: Option<String>, // JWT ID for token tracking
}
#[derive(Debug, Serialize, Deserialize)]
@@ -27,6 +30,8 @@ pub struct ErrorResponse {
pub error: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub error_description: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub error_uri: Option<String>,
}
#[derive(Debug, Clone)]
@@ -36,4 +41,44 @@ pub struct AuthCode {
pub scope: Option<String>,
pub expires_at: u64,
pub user_id: String,
+ // PKCE support
+ pub code_challenge: Option<String>,
+ pub code_challenge_method: Option<CodeChallengeMethod>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct TokenIntrospectionRequest {
+ pub token: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub token_type_hint: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct TokenIntrospectionResponse {
+ pub active: bool,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub client_id: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub username: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub scope: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub exp: Option<u64>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub iat: Option<u64>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub sub: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub aud: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub iss: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub jti: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct TokenRevocationRequest {
+ pub token: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub token_type_hint: Option<String>,
}