summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-11 15:12:59 -0600
committermo khan <mo@mokhan.ca>2025-06-11 15:12:59 -0600
commit4435ee26b79648e92d0f172e42f9e6629e955505 (patch)
tree0720fd07c879a58672fcfcb2e45ed1161430f039 /src
parent39c67cfc6c74bf4b26ba455f3adda1241aea35ea (diff)
chore: rustfmt and include Connection: header in responses
Diffstat (limited to 'src')
-rw-r--r--src/bin/debug.rs36
-rw-r--r--src/bin/migrate.rs9
-rw-r--r--src/http/mod.rs53
-rw-r--r--src/oauth/mod.rs4
-rw-r--r--src/oauth/pkce.rs57
-rw-r--r--src/oauth/server.rs240
-rw-r--r--src/oauth/types.rs2
7 files changed, 307 insertions, 94 deletions
diff --git a/src/bin/debug.rs b/src/bin/debug.rs
new file mode 100644
index 0000000..6d80848
--- /dev/null
+++ b/src/bin/debug.rs
@@ -0,0 +1,36 @@
+fn main() {
+ let config = sts::Config::from_env();
+ println!("Config loaded: {}", config.bind_addr);
+ let server = sts::http::Server::new(config.clone());
+ println!("Server result: {:?}", server.is_ok());
+
+ if let Ok(server) = server {
+ let oauth_server = &server.oauth_server;
+ let jwks = oauth_server.get_jwks();
+ println!("JWKS length: {}", jwks.len());
+ println!(
+ "JWKS: {}",
+ if jwks.len() > 100 {
+ &jwks[..100]
+ } else {
+ &jwks
+ }
+ );
+ }
+
+ let metadata = serde_json::json!({
+ "issuer": config.issuer_url,
+ "authorization_endpoint": format!("{}/authorize", config.issuer_url),
+ "token_endpoint": format!("{}/token", config.issuer_url)
+ });
+ let metadata_str = metadata.to_string();
+ println!("Metadata length: {}", metadata_str.len());
+ println!(
+ "Metadata: {}",
+ if metadata_str.len() > 100 {
+ &metadata_str[..100]
+ } else {
+ &metadata_str
+ }
+ );
+}
diff --git a/src/bin/migrate.rs b/src/bin/migrate.rs
index 9afdbf0..9a0bab9 100644
--- a/src/bin/migrate.rs
+++ b/src/bin/migrate.rs
@@ -1,11 +1,11 @@
use anyhow::Result;
use rusqlite::Connection;
-use sts::{Config, MigrationRunner};
use std::env;
+use sts::{Config, MigrationRunner};
fn main() -> Result<()> {
let args: Vec<String> = env::args().collect();
-
+
if args.len() < 2 {
print_usage();
return Ok(());
@@ -29,7 +29,8 @@ fn main() -> Result<()> {
eprintln!("Usage: cargo run --bin migrate rollback <version>");
return Ok(());
}
- let version: i32 = args[2].parse()
+ let version: i32 = args[2]
+ .parse()
.map_err(|_| anyhow::anyhow!("Invalid version number: {}", args[2]))?;
runner.rollback_to_version(version)?;
}
@@ -58,4 +59,4 @@ fn print_usage() {
println!(" cargo run --bin migrate up");
println!(" cargo run --bin migrate status");
println!(" cargo run --bin migrate rollback 0");
-} \ No newline at end of file
+}
diff --git a/src/http/mod.rs b/src/http/mod.rs
index c8d485b..1bc7951 100644
--- a/src/http/mod.rs
+++ b/src/http/mod.rs
@@ -8,13 +8,14 @@ use url::Url;
pub struct Server {
config: Config,
- oauth_server: OAuthServer,
+ pub oauth_server: OAuthServer,
}
impl Server {
pub fn new(config: Config) -> Result<Server, Box<dyn std::error::Error>> {
Ok(Server {
- oauth_server: OAuthServer::new(&config).map_err(|e| format!("Failed to create OAuth server: {}", e))?,
+ oauth_server: OAuthServer::new(&config)
+ .map_err(|e| format!("Failed to create OAuth server: {}", e))?,
config,
})
}
@@ -69,7 +70,7 @@ impl Server {
// Extract IP address for audit logging
let ip_address = stream.peer_addr().ok().map(|addr| addr.ip().to_string());
-
+
match (method, path) {
("GET", "/") => self.serve_static_file(&mut stream, "./public/index.html"),
("GET", "/.well-known/oauth-authorization-server") => self.handle_metadata(&mut stream),
@@ -92,13 +93,13 @@ impl Server {
};
let response = format!(
- "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n{}",
+ "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
content_type,
contents.len(),
contents
);
let _ = stream.write_all(response.as_bytes());
- let _ = stream.flush();
+ let _ = stream.flush();
}
Err(_) => self.send_error_response(stream, 404, "Not Found"),
}
@@ -106,7 +107,7 @@ impl Server {
fn send_error_response(&self, stream: &mut TcpStream, status: u16, message: &str) {
let response = format!(
- "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}",
+ "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
status,
message,
message.len(),
@@ -123,38 +124,38 @@ impl Server {
status_text: &str,
json: &str,
) {
- let security_headers = self.get_security_headers();
let response = format!(
- "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n{}\r\n{}",
+ "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
status,
status_text,
json.len(),
- security_headers,
json
);
let _ = stream.write_all(response.as_bytes());
let _ = stream.flush();
}
-
+
fn send_empty_response(&self, stream: &mut TcpStream, status: u16, status_text: &str) {
let security_headers = self.get_security_headers();
let response = format!(
- "HTTP/1.1 {} {}\r\nContent-Length: 0\r\n{}\r\n",
- status,
- status_text,
- security_headers
+ "HTTP/1.1 {} {}\r\nContent-Length: 0\r\nConnection: close\r\n{}\r\n",
+ status, status_text, security_headers
);
let _ = stream.write_all(response.as_bytes());
let _ = stream.flush();
}
-
+
fn get_security_headers(&self) -> String {
let cors_origin = if self.config.cors_allowed_origins.contains(&"*".to_string()) {
"*".to_string()
} else {
- self.config.cors_allowed_origins.first().unwrap_or(&"*".to_string()).clone()
+ self.config
+ .cors_allowed_origins
+ .first()
+ .unwrap_or(&"*".to_string())
+ .clone()
};
-
+
format!(
"Access-Control-Allow-Origin: {}\r\n\
Access-Control-Allow-Methods: GET, POST, OPTIONS\r\n\
@@ -197,17 +198,21 @@ impl Server {
self.send_json_response(stream, 200, "OK", &jwks);
}
- fn handle_authorize(&self, stream: &mut TcpStream, params: &HashMap<String, String>, ip_address: Option<String>) {
+ fn handle_authorize(
+ &self,
+ stream: &mut TcpStream,
+ params: &HashMap<String, String>,
+ ip_address: Option<String>,
+ ) {
match self.oauth_server.handle_authorize(params, ip_address) {
Ok(redirect_url) => {
let security_headers = self.get_security_headers();
let response = format!(
- "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\n{}\r\n",
- redirect_url,
- security_headers
+ "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\nConnection: close\r\n{}\r\n",
+ redirect_url, security_headers
);
let _ = stream.write_all(response.as_bytes());
- let _ = stream.flush();
+ let _ = stream.flush();
}
Err(error_response) => {
self.send_json_response(stream, 400, "Bad Request", &error_response);
@@ -234,7 +239,7 @@ impl Server {
}
}
}
-
+
fn handle_introspect(&self, stream: &mut TcpStream, request: &str) {
let body = self.extract_body(request);
let form_params = self.parse_form_data(&body);
@@ -252,7 +257,7 @@ impl Server {
}
}
}
-
+
fn handle_revoke(&self, stream: &mut TcpStream, request: &str) {
let body = self.extract_body(request);
let form_params = self.parse_form_data(&body);
diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs
index 7fd0d7b..4b18bb3 100644
--- a/src/oauth/mod.rs
+++ b/src/oauth/mod.rs
@@ -2,6 +2,8 @@ pub mod pkce;
pub mod server;
pub mod types;
-pub use pkce::{CodeChallengeMethod, verify_code_challenge, generate_code_verifier, generate_code_challenge};
+pub use pkce::{
+ generate_code_challenge, generate_code_verifier, verify_code_challenge, CodeChallengeMethod,
+};
pub use server::OAuthServer;
pub use types::{AuthCode, Claims, ErrorResponse, TokenResponse};
diff --git a/src/oauth/pkce.rs b/src/oauth/pkce.rs
index c943844..406d364 100644
--- a/src/oauth/pkce.rs
+++ b/src/oauth/pkce.rs
@@ -1,6 +1,6 @@
-use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
-use sha2::{Digest, Sha256};
use anyhow::{anyhow, Result};
+use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
+use sha2::{Digest, Sha256};
#[derive(Debug, Clone, PartialEq)]
pub enum CodeChallengeMethod {
@@ -32,13 +32,16 @@ pub fn verify_code_challenge(
) -> Result<bool> {
// Validate code verifier format (RFC 7636 Section 4.1)
if code_verifier.len() < 43 || code_verifier.len() > 128 {
- return Err(anyhow!("Code verifier length must be between 43 and 128 characters"));
+ return Err(anyhow!(
+ "Code verifier length must be between 43 and 128 characters"
+ ));
}
// Code verifier must only contain unreserved characters
- if !code_verifier.chars().all(|c| {
- c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~'
- }) {
+ if !code_verifier
+ .chars()
+ .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~')
+ {
return Err(anyhow!("Code verifier contains invalid characters"));
}
@@ -57,7 +60,7 @@ pub fn verify_code_challenge(
pub fn generate_code_verifier() -> String {
use rand::Rng;
let mut rng = rand::thread_rng();
-
+
// Generate 32 random bytes and encode them
let bytes: Vec<u8> = (0..32).map(|_| rng.r#gen()).collect();
URL_SAFE_NO_PAD.encode(&bytes)
@@ -80,8 +83,14 @@ mod tests {
#[test]
fn test_code_challenge_method_from_str() {
- assert_eq!(CodeChallengeMethod::from_str("plain").unwrap(), CodeChallengeMethod::Plain);
- assert_eq!(CodeChallengeMethod::from_str("S256").unwrap(), CodeChallengeMethod::S256);
+ assert_eq!(
+ CodeChallengeMethod::from_str("plain").unwrap(),
+ CodeChallengeMethod::Plain
+ );
+ assert_eq!(
+ CodeChallengeMethod::from_str("S256").unwrap(),
+ CodeChallengeMethod::S256
+ );
assert!(CodeChallengeMethod::from_str("invalid").is_err());
}
@@ -95,7 +104,7 @@ mod tests {
fn test_verify_code_challenge_plain() {
let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
let challenge = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
-
+
assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::Plain).unwrap());
assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::Plain).unwrap());
}
@@ -104,7 +113,7 @@ mod tests {
fn test_verify_code_challenge_s256() {
let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
let challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
-
+
assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::S256).unwrap());
assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::S256).unwrap());
}
@@ -113,9 +122,14 @@ mod tests {
fn test_verify_code_challenge_invalid_verifier() {
// Too short
assert!(verify_code_challenge("short", "challenge", &CodeChallengeMethod::Plain).is_err());
-
+
// Invalid characters
- assert!(verify_code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", "challenge", &CodeChallengeMethod::Plain).is_err());
+ assert!(verify_code_challenge(
+ "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!",
+ "challenge",
+ &CodeChallengeMethod::Plain
+ )
+ .is_err());
}
#[test]
@@ -123,7 +137,7 @@ mod tests {
let verifier = generate_code_verifier();
assert!(verifier.len() >= 43);
assert!(verifier.len() <= 128);
-
+
// Should only contain valid characters
assert!(verifier.chars().all(|c| {
c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~'
@@ -133,24 +147,27 @@ mod tests {
#[test]
fn test_generate_code_challenge() {
let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
-
+
let plain_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::Plain);
assert_eq!(plain_challenge, verifier);
-
+
let s256_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::S256);
- assert_eq!(s256_challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM");
+ assert_eq!(
+ s256_challenge,
+ "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
+ );
}
#[test]
fn test_round_trip() {
let verifier = generate_code_verifier();
-
+
// Test with S256
let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::S256);
assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::S256).unwrap());
-
+
// Test with Plain
let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::Plain);
assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::Plain).unwrap());
}
-} \ No newline at end of file
+}
diff --git a/src/oauth/server.rs b/src/oauth/server.rs
index 7552f00..7fd8b9c 100644
--- a/src/oauth/server.rs
+++ b/src/oauth/server.rs
@@ -1,9 +1,9 @@
use crate::clients::{parse_basic_auth, ClientManager};
use crate::config::Config;
-use crate::database::{Database, DbAuthCode, DbAccessToken, DbAuditLog};
+use crate::database::{Database, DbAccessToken, DbAuditLog, DbAuthCode};
use crate::keys::KeyManager;
-use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge};
-use crate::oauth::types::{Claims, ErrorResponse, TokenResponse, TokenIntrospectionResponse};
+use crate::oauth::pkce::{verify_code_challenge, CodeChallengeMethod};
+use crate::oauth::types::{Claims, ErrorResponse, TokenIntrospectionResponse, TokenResponse};
use anyhow::{anyhow, Result};
use chrono::{Duration, Utc};
use jsonwebtoken::{encode, Algorithm, Header};
@@ -43,7 +43,11 @@ impl OAuthServer {
}
}
- pub fn handle_authorize(&self, params: &HashMap<String, String>, ip_address: Option<String>) -> Result<String, String> {
+ pub fn handle_authorize(
+ &self,
+ params: &HashMap<String, String>,
+ ip_address: Option<String>,
+ ) -> Result<String, String> {
let client_id = params
.get("client_id")
.ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
@@ -58,7 +62,14 @@ impl OAuthServer {
// Rate limiting check
if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/authorize") {
- self.audit_log("authorize_rate_limited", Some(client_id), None, ip_address.as_deref(), false, Some(&e.to_string()));
+ self.audit_log(
+ "authorize_rate_limited",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(&e.to_string()),
+ );
return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded"));
}
@@ -68,7 +79,14 @@ impl OAuthServer {
match client_manager.get_client_from_db(client_id) {
Ok(Some(client)) => client,
Ok(None) => {
- self.audit_log("authorize_invalid_client", Some(client_id), None, ip_address.as_deref(), false, None);
+ self.audit_log(
+ "authorize_invalid_client",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ None,
+ );
return Err(self.error_response("invalid_client", "Invalid client_id"));
}
Err(_) => {
@@ -81,7 +99,14 @@ impl OAuthServer {
{
let mut client_manager = self.client_manager.lock().unwrap();
if !client_manager.is_redirect_uri_valid(client_id, redirect_uri) {
- self.audit_log("authorize_invalid_redirect_uri", Some(client_id), None, ip_address.as_deref(), false, Some(redirect_uri));
+ self.audit_log(
+ "authorize_invalid_redirect_uri",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(redirect_uri),
+ );
return Err(self.error_response("invalid_request", "Invalid redirect_uri"));
}
}
@@ -91,13 +116,27 @@ impl OAuthServer {
{
let mut client_manager = self.client_manager.lock().unwrap();
if !client_manager.is_scope_valid(client_id, &scope) {
- self.audit_log("authorize_invalid_scope", Some(client_id), None, ip_address.as_deref(), false, scope.as_deref());
+ self.audit_log(
+ "authorize_invalid_scope",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ scope.as_deref(),
+ );
return Err(self.error_response("invalid_scope", "Invalid scope"));
}
}
if response_type != "code" {
- self.audit_log("authorize_unsupported_response_type", Some(client_id), None, ip_address.as_deref(), false, Some(response_type));
+ self.audit_log(
+ "authorize_unsupported_response_type",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(response_type),
+ );
return Err(self.error_response(
"unsupported_response_type",
"Only code response type supported",
@@ -106,14 +145,22 @@ impl OAuthServer {
// PKCE validation (RFC 7636)
let code_challenge = params.get("code_challenge");
- let code_challenge_method = params.get("code_challenge_method")
+ let code_challenge_method = params
+ .get("code_challenge_method")
.map(|method| CodeChallengeMethod::from_str(method))
.transpose()
.map_err(|_| self.error_response("invalid_request", "Invalid code_challenge_method"))?;
// For public clients, PKCE is required
if client.client_id.starts_with("public_") && code_challenge.is_none() {
- self.audit_log("authorize_missing_pkce", Some(client_id), None, ip_address.as_deref(), false, None);
+ self.audit_log(
+ "authorize_missing_pkce",
+ Some(client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ None,
+ );
return Err(self.error_response("invalid_request", "PKCE required for public clients"));
}
@@ -131,14 +178,18 @@ impl OAuthServer {
created_at: Utc::now(),
is_used: false,
code_challenge: code_challenge.cloned(),
- code_challenge_method: code_challenge_method.as_ref().map(|m| m.as_str().to_string()),
+ code_challenge_method: code_challenge_method
+ .as_ref()
+ .map(|m| m.as_str().to_string()),
};
// Save to database
{
let db = self.database.lock().unwrap();
if let Err(_) = db.create_auth_code(&db_auth_code) {
- return Err(self.error_response("server_error", "Failed to create authorization code"));
+ return Err(
+ self.error_response("server_error", "Failed to create authorization code")
+ );
}
}
@@ -151,7 +202,14 @@ impl OAuthServer {
redirect_url.query_pairs_mut().append_pair("state", state);
}
- self.audit_log("authorize_success", Some(client_id), Some("test_user"), ip_address.as_deref(), true, None);
+ self.audit_log(
+ "authorize_success",
+ Some(client_id),
+ Some("test_user"),
+ ip_address.as_deref(),
+ true,
+ None,
+ );
Ok(redirect_url.to_string())
}
@@ -167,14 +225,20 @@ impl OAuthServer {
.ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?;
match grant_type.as_str() {
- "authorization_code" => self.handle_authorization_code_grant(params, auth_header, ip_address),
+ "authorization_code" => {
+ self.handle_authorization_code_grant(params, auth_header, ip_address)
+ }
"refresh_token" => self.handle_refresh_token_grant(params, auth_header, ip_address),
_ => {
- self.audit_log("token_unsupported_grant_type", None, None, ip_address.as_deref(), false, Some(grant_type));
- Err(self.error_response(
- "unsupported_grant_type",
- "Unsupported grant type",
- ))
+ self.audit_log(
+ "token_unsupported_grant_type",
+ None,
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(grant_type),
+ );
+ Err(self.error_response("unsupported_grant_type", "Unsupported grant type"))
}
}
}
@@ -208,7 +272,14 @@ impl OAuthServer {
// Rate limiting check
if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/token") {
- self.audit_log("token_rate_limited", Some(&client_id), None, ip_address.as_deref(), false, Some(&e.to_string()));
+ self.audit_log(
+ "token_rate_limited",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(&e.to_string()),
+ );
return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded"));
}
@@ -216,7 +287,14 @@ impl OAuthServer {
{
let mut client_manager = self.client_manager.lock().unwrap();
if !client_manager.authenticate_client(&client_id, &client_secret) {
- self.audit_log("token_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None);
+ self.audit_log(
+ "token_invalid_client",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ None,
+ );
return Err(self.error_response("invalid_client", "Client authentication failed"));
}
}
@@ -227,8 +305,16 @@ impl OAuthServer {
match db.get_auth_code(code) {
Ok(Some(auth_code)) => auth_code,
Ok(None) => {
- self.audit_log("token_invalid_code", Some(&client_id), None, ip_address.as_deref(), false, Some(code));
- return Err(self.error_response("invalid_grant", "Invalid or expired authorization code"));
+ self.audit_log(
+ "token_invalid_code",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ Some(code),
+ );
+ return Err(self
+ .error_response("invalid_grant", "Invalid or expired authorization code"));
}
Err(_) => {
return Err(self.error_response("server_error", "Internal server error"));
@@ -238,17 +324,38 @@ impl OAuthServer {
// Validate code hasn't been used and hasn't expired
if auth_code.is_used {
- self.audit_log("token_code_reuse", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code));
+ self.audit_log(
+ "token_code_reuse",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ Some(code),
+ );
return Err(self.error_response("invalid_grant", "Authorization code already used"));
}
if Utc::now() > auth_code.expires_at {
- self.audit_log("token_code_expired", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code));
+ self.audit_log(
+ "token_code_expired",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ Some(code),
+ );
return Err(self.error_response("invalid_grant", "Authorization code expired"));
}
if auth_code.client_id != client_id {
- self.audit_log("token_client_mismatch", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None);
+ self.audit_log(
+ "token_client_mismatch",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ None,
+ );
return Err(self.error_response("invalid_grant", "Client ID mismatch"));
}
@@ -258,13 +365,22 @@ impl OAuthServer {
self.error_response("invalid_request", "Missing code_verifier for PKCE")
})?;
- let challenge_method = auth_code.code_challenge_method
+ let challenge_method = auth_code
+ .code_challenge_method
.as_ref()
.and_then(|method| CodeChallengeMethod::from_str(method).ok())
.unwrap_or(CodeChallengeMethod::Plain);
- if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method) {
- self.audit_log("token_pkce_verification_failed", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None);
+ if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method)
+ {
+ self.audit_log(
+ "token_pkce_verification_failed",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ false,
+ None,
+ );
return Err(self.error_response("invalid_grant", "PKCE verification failed"));
}
}
@@ -279,8 +395,14 @@ impl OAuthServer {
// Generate tokens
let token_id = Uuid::new_v4().to_string();
- let access_token = self.generate_access_token(&auth_code.user_id, &client_id, &auth_code.scope, &token_id)?;
- let refresh_token = self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?;
+ let access_token = self.generate_access_token(
+ &auth_code.user_id,
+ &client_id,
+ &auth_code.scope,
+ &token_id,
+ )?;
+ let refresh_token =
+ self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?;
// Store token in database for revocation/introspection
let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes()));
@@ -311,7 +433,14 @@ impl OAuthServer {
scope: auth_code.scope,
};
- self.audit_log("token_success", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), true, None);
+ self.audit_log(
+ "token_success",
+ Some(&client_id),
+ Some(&auth_code.user_id),
+ ip_address.as_deref(),
+ true,
+ None,
+ );
serde_json::to_string(&token_response)
.map_err(|_| self.error_response("server_error", "Failed to serialize token response"))
@@ -346,7 +475,14 @@ impl OAuthServer {
{
let mut client_manager = self.client_manager.lock().unwrap();
if !client_manager.authenticate_client(&client_id, &client_secret) {
- self.audit_log("refresh_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None);
+ self.audit_log(
+ "refresh_invalid_client",
+ Some(&client_id),
+ None,
+ ip_address.as_deref(),
+ false,
+ None,
+ );
return Err(self.error_response("invalid_client", "Client authentication failed"));
}
}
@@ -354,7 +490,8 @@ impl OAuthServer {
// Validate refresh token (implementation would verify token and get user info)
// For now, return a simple refresh token response
let new_token_id = Uuid::new_v4().to_string();
- let access_token = self.generate_access_token("test_user", &client_id, &None, &new_token_id)?;
+ let access_token =
+ self.generate_access_token("test_user", &client_id, &None, &new_token_id)?;
let new_refresh_token = self.generate_refresh_token(&client_id, "test_user", &None)?;
let token_response = TokenResponse {
@@ -365,7 +502,14 @@ impl OAuthServer {
scope: None,
};
- self.audit_log("refresh_success", Some(&client_id), Some("test_user"), ip_address.as_deref(), true, None);
+ self.audit_log(
+ "refresh_success",
+ Some(&client_id),
+ Some("test_user"),
+ ip_address.as_deref(),
+ true,
+ None,
+ );
serde_json::to_string(&token_response)
.map_err(|_| self.error_response("server_error", "Failed to serialize token response"))
@@ -543,15 +687,23 @@ impl OAuthServer {
fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> {
let db = self.database.lock().unwrap();
let count = db.increment_rate_limit(identifier, endpoint, 1)?;
-
+
if count > self.config.rate_limit_requests_per_minute as i32 {
return Err(anyhow!("Rate limit exceeded"));
}
-
+
Ok(())
}
- fn audit_log(&self, event_type: &str, client_id: Option<&str>, user_id: Option<&str>, ip_address: Option<&str>, success: bool, details: Option<&str>) {
+ fn audit_log(
+ &self,
+ event_type: &str,
+ client_id: Option<&str>,
+ user_id: Option<&str>,
+ ip_address: Option<&str>,
+ success: bool,
+ details: Option<&str>,
+ ) {
if !self.config.enable_audit_logging {
return;
}
@@ -584,16 +736,16 @@ impl OAuthServer {
// Cleanup expired data
pub fn cleanup_expired_data(&self) -> Result<()> {
let db = self.database.lock().unwrap();
-
+
// Cleanup expired authorization codes
let _ = db.cleanup_expired_codes();
-
+
// Cleanup expired tokens
let _ = db.cleanup_expired_tokens();
-
+
// Cleanup old audit logs (keep for 30 days)
let _ = db.cleanup_old_audit_logs(30);
-
+
Ok(())
}
-} \ No newline at end of file
+}
diff --git a/src/oauth/types.rs b/src/oauth/types.rs
index 0f9be5c..4f2c363 100644
--- a/src/oauth/types.rs
+++ b/src/oauth/types.rs
@@ -1,5 +1,5 @@
-use serde::{Deserialize, Serialize};
use crate::oauth::pkce::CodeChallengeMethod;
+use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {