summaryrefslogtreecommitdiff
path: root/src/oauth/server.rs
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-11 15:12:59 -0600
committermo khan <mo@mokhan.ca>2025-06-11 15:12:59 -0600
commit4435ee26b79648e92d0f172e42f9e6629e955505 (patch)
tree0720fd07c879a58672fcfcb2e45ed1161430f039 /src/oauth/server.rs
parent39c67cfc6c74bf4b26ba455f3adda1241aea35ea (diff)
chore: rustfmt and include Connection: header in responses
Diffstat (limited to 'src/oauth/server.rs')
-rw-r--r--src/oauth/server.rs240
1 files changed, 196 insertions, 44 deletions
diff --git a/src/oauth/server.rs b/src/oauth/server.rs
index 7552f00..7fd8b9c 100644
--- a/src/oauth/server.rs
+++ b/src/oauth/server.rs
@@ -1,9 +1,9 @@
use crate::clients::{parse_basic_auth, ClientManager};
use crate::config::Config;
-use crate::database::{Database, DbAuthCode, DbAccessToken, DbAuditLog};
+use crate::database::{Database, DbAccessToken, DbAuditLog, DbAuthCode};
use crate::keys::KeyManager;
-use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge};
-use crate::oauth::types::{Claims, ErrorResponse, TokenResponse, TokenIntrospectionResponse};
+use crate::oauth::pkce::{verify_code_challenge, CodeChallengeMethod};
+use crate::oauth::types::{Claims, ErrorResponse, TokenIntrospectionResponse, TokenResponse};
use anyhow::{anyhow, Result};
use chrono::{Duration, Utc};
use jsonwebtoken::{encode, Algorithm, Header};
@@ -43,7 +43,11 @@ impl OAuthServer {
}
}
- pub fn handle_authorize(&self, params: &HashMap<String, String>, ip_address: Option<String>) -> Result<String, String> {
+ pub fn handle_authorize(
+ &self,
+ params: &HashMap<String, String>,
+ ip_address: Option<String>,
+ ) -> Result<String, String> {
let client_id = params
.get("client_id")
.ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
@@ -58,7 +62,14 @@ impl OAuthServer {
// Rate limiting check
if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/authorize") {
- self.audit_log("authorize_rate_limited", Some(client_id), None, ip_address.as_deref(), false, Some(&e.to_string()));
+ self.audit_log(
+ "authorize_rate_limited",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(&e.to_string()),
+ );
return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded"));
}
@@ -68,7 +79,14 @@ impl OAuthServer {
match client_manager.get_client_from_db(client_id) {
Ok(Some(client)) => client,
Ok(None) => {
- self.audit_log("authorize_invalid_client", Some(client_id), None, ip_address.as_deref(), false, None);
+ self.audit_log(
+ "authorize_invalid_client",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ None,
+ );
return Err(self.error_response("invalid_client", "Invalid client_id"));
}
Err(_) => {
@@ -81,7 +99,14 @@ impl OAuthServer {
{
let mut client_manager = self.client_manager.lock().unwrap();
if !client_manager.is_redirect_uri_valid(client_id, redirect_uri) {
- self.audit_log("authorize_invalid_redirect_uri", Some(client_id), None, ip_address.as_deref(), false, Some(redirect_uri));
+ self.audit_log(
+ "authorize_invalid_redirect_uri",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(redirect_uri),
+ );
return Err(self.error_response("invalid_request", "Invalid redirect_uri"));
}
}
@@ -91,13 +116,27 @@ impl OAuthServer {
{
let mut client_manager = self.client_manager.lock().unwrap();
if !client_manager.is_scope_valid(client_id, &scope) {
- self.audit_log("authorize_invalid_scope", Some(client_id), None, ip_address.as_deref(), false, scope.as_deref());
+ self.audit_log(
+ "authorize_invalid_scope",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ scope.as_deref(),
+ );
return Err(self.error_response("invalid_scope", "Invalid scope"));
}
}
if response_type != "code" {
- self.audit_log("authorize_unsupported_response_type", Some(client_id), None, ip_address.as_deref(), false, Some(response_type));
+ self.audit_log(
+ "authorize_unsupported_response_type",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(response_type),
+ );
return Err(self.error_response(
"unsupported_response_type",
"Only code response type supported",
@@ -106,14 +145,22 @@ impl OAuthServer {
// PKCE validation (RFC 7636)
let code_challenge = params.get("code_challenge");
- let code_challenge_method = params.get("code_challenge_method")
+ let code_challenge_method = params
+ .get("code_challenge_method")
.map(|method| CodeChallengeMethod::from_str(method))
.transpose()
.map_err(|_| self.error_response("invalid_request", "Invalid code_challenge_method"))?;
// For public clients, PKCE is required
if client.client_id.starts_with("public_") && code_challenge.is_none() {
- self.audit_log("authorize_missing_pkce", Some(client_id), None, ip_address.as_deref(), false, None);
+ self.audit_log(
+ "authorize_missing_pkce",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ None,
+ );
return Err(self.error_response("invalid_request", "PKCE required for public clients"));
}
@@ -131,14 +178,18 @@ impl OAuthServer {
created_at: Utc::now(),
is_used: false,
code_challenge: code_challenge.cloned(),
- code_challenge_method: code_challenge_method.as_ref().map(|m| m.as_str().to_string()),
+ code_challenge_method: code_challenge_method
+ .as_ref()
+ .map(|m| m.as_str().to_string()),
};
// Save to database
{
let db = self.database.lock().unwrap();
if let Err(_) = db.create_auth_code(&db_auth_code) {
- return Err(self.error_response("server_error", "Failed to create authorization code"));
+ return Err(
+ self.error_response("server_error", "Failed to create authorization code")
+ );
}
}
@@ -151,7 +202,14 @@ impl OAuthServer {
redirect_url.query_pairs_mut().append_pair("state", state);
}
- self.audit_log("authorize_success", Some(client_id), Some("test_user"), ip_address.as_deref(), true, None);
+ self.audit_log(
+ "authorize_success",
+ Some(client_id),
+ Some("test_user"),
+ ip_address.as_deref(),
+ true,
+ None,
+ );
Ok(redirect_url.to_string())
}
@@ -167,14 +225,20 @@ impl OAuthServer {
.ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?;
match grant_type.as_str() {
- "authorization_code" => self.handle_authorization_code_grant(params, auth_header, ip_address),
+ "authorization_code" => {
+ self.handle_authorization_code_grant(params, auth_header, ip_address)
+ }
"refresh_token" => self.handle_refresh_token_grant(params, auth_header, ip_address),
_ => {
- self.audit_log("token_unsupported_grant_type", None, None, ip_address.as_deref(), false, Some(grant_type));
- Err(self.error_response(
- "unsupported_grant_type",
- "Unsupported grant type",
- ))
+ self.audit_log(
+ "token_unsupported_grant_type",
+ None,
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(grant_type),
+ );
+ Err(self.error_response("unsupported_grant_type", "Unsupported grant type"))
}
}
}
@@ -208,7 +272,14 @@ impl OAuthServer {
// Rate limiting check
if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/token") {
- self.audit_log("token_rate_limited", Some(&client_id), None, ip_address.as_deref(), false, Some(&e.to_string()));
+ self.audit_log(
+ "token_rate_limited",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(&e.to_string()),
+ );
return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded"));
}
@@ -216,7 +287,14 @@ impl OAuthServer {
{
let mut client_manager = self.client_manager.lock().unwrap();
if !client_manager.authenticate_client(&client_id, &client_secret) {
- self.audit_log("token_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None);
+ self.audit_log(
+ "token_invalid_client",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ None,
+ );
return Err(self.error_response("invalid_client", "Client authentication failed"));
}
}
@@ -227,8 +305,16 @@ impl OAuthServer {
match db.get_auth_code(code) {
Ok(Some(auth_code)) => auth_code,
Ok(None) => {
- self.audit_log("token_invalid_code", Some(&client_id), None, ip_address.as_deref(), false, Some(code));
- return Err(self.error_response("invalid_grant", "Invalid or expired authorization code"));
+ self.audit_log(
+ "token_invalid_code",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(code),
+ );
+ return Err(self
+ .error_response("invalid_grant", "Invalid or expired authorization code"));
}
Err(_) => {
return Err(self.error_response("server_error", "Internal server error"));
@@ -238,17 +324,38 @@ impl OAuthServer {
// Validate code hasn't been used and hasn't expired
if auth_code.is_used {
- self.audit_log("token_code_reuse", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code));
+ self.audit_log(
+ "token_code_reuse",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ Some(code),
+ );
return Err(self.error_response("invalid_grant", "Authorization code already used"));
}
if Utc::now() > auth_code.expires_at {
- self.audit_log("token_code_expired", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code));
+ self.audit_log(
+ "token_code_expired",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ Some(code),
+ );
return Err(self.error_response("invalid_grant", "Authorization code expired"));
}
if auth_code.client_id != client_id {
- self.audit_log("token_client_mismatch", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None);
+ self.audit_log(
+ "token_client_mismatch",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ None,
+ );
return Err(self.error_response("invalid_grant", "Client ID mismatch"));
}
@@ -258,13 +365,22 @@ impl OAuthServer {
self.error_response("invalid_request", "Missing code_verifier for PKCE")
})?;
- let challenge_method = auth_code.code_challenge_method
+ let challenge_method = auth_code
+ .code_challenge_method
.as_ref()
.and_then(|method| CodeChallengeMethod::from_str(method).ok())
.unwrap_or(CodeChallengeMethod::Plain);
- if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method) {
- self.audit_log("token_pkce_verification_failed", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None);
+ if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method)
+ {
+ self.audit_log(
+ "token_pkce_verification_failed",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ None,
+ );
return Err(self.error_response("invalid_grant", "PKCE verification failed"));
}
}
@@ -279,8 +395,14 @@ impl OAuthServer {
// Generate tokens
let token_id = Uuid::new_v4().to_string();
- let access_token = self.generate_access_token(&auth_code.user_id, &client_id, &auth_code.scope, &token_id)?;
- let refresh_token = self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?;
+ let access_token = self.generate_access_token(
+ &auth_code.user_id,
+ &client_id,
+ &auth_code.scope,
+ &token_id,
+ )?;
+ let refresh_token =
+ self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?;
// Store token in database for revocation/introspection
let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes()));
@@ -311,7 +433,14 @@ impl OAuthServer {
scope: auth_code.scope,
};
- self.audit_log("token_success", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), true, None);
+ self.audit_log(
+ "token_success",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ true,
+ None,
+ );
serde_json::to_string(&token_response)
.map_err(|_| self.error_response("server_error", "Failed to serialize token response"))
@@ -346,7 +475,14 @@ impl OAuthServer {
{
let mut client_manager = self.client_manager.lock().unwrap();
if !client_manager.authenticate_client(&client_id, &client_secret) {
- self.audit_log("refresh_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None);
+ self.audit_log(
+ "refresh_invalid_client",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ None,
+ );
return Err(self.error_response("invalid_client", "Client authentication failed"));
}
}
@@ -354,7 +490,8 @@ impl OAuthServer {
// Validate refresh token (implementation would verify token and get user info)
// For now, return a simple refresh token response
let new_token_id = Uuid::new_v4().to_string();
- let access_token = self.generate_access_token("test_user", &client_id, &None, &new_token_id)?;
+ let access_token =
+ self.generate_access_token("test_user", &client_id, &None, &new_token_id)?;
let new_refresh_token = self.generate_refresh_token(&client_id, "test_user", &None)?;
let token_response = TokenResponse {
@@ -365,7 +502,14 @@ impl OAuthServer {
scope: None,
};
- self.audit_log("refresh_success", Some(&client_id), Some("test_user"), ip_address.as_deref(), true, None);
+ self.audit_log(
+ "refresh_success",
+ Some(&client_id),
+ Some("test_user"),
+ ip_address.as_deref(),
+ true,
+ None,
+ );
serde_json::to_string(&token_response)
.map_err(|_| self.error_response("server_error", "Failed to serialize token response"))
@@ -543,15 +687,23 @@ impl OAuthServer {
fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> {
let db = self.database.lock().unwrap();
let count = db.increment_rate_limit(identifier, endpoint, 1)?;
-
+
if count > self.config.rate_limit_requests_per_minute as i32 {
return Err(anyhow!("Rate limit exceeded"));
}
-
+
Ok(())
}
- fn audit_log(&self, event_type: &str, client_id: Option<&str>, user_id: Option<&str>, ip_address: Option<&str>, success: bool, details: Option<&str>) {
+ fn audit_log(
+ &self,
+ event_type: &str,
+ client_id: Option<&str>,
+ user_id: Option<&str>,
+ ip_address: Option<&str>,
+ success: bool,
+ details: Option<&str>,
+ ) {
if !self.config.enable_audit_logging {
return;
}
@@ -584,16 +736,16 @@ impl OAuthServer {
// Cleanup expired data
pub fn cleanup_expired_data(&self) -> Result<()> {
let db = self.database.lock().unwrap();
-
+
// Cleanup expired authorization codes
let _ = db.cleanup_expired_codes();
-
+
// Cleanup expired tokens
let _ = db.cleanup_expired_tokens();
-
+
// Cleanup old audit logs (keep for 30 days)
let _ = db.cleanup_old_audit_logs(30);
-
+
Ok(())
}
-} \ No newline at end of file
+}