diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-11 15:12:59 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-11 15:12:59 -0600 |
| commit | 4435ee26b79648e92d0f172e42f9e6629e955505 (patch) | |
| tree | 0720fd07c879a58672fcfcb2e45ed1161430f039 | |
| parent | 39c67cfc6c74bf4b26ba455f3adda1241aea35ea (diff) | |
chore: rustfmt and include Connection: header in responses
| -rw-r--r-- | Cargo.toml | 4 | ||||
| -rw-r--r-- | src/bin/debug.rs | 36 | ||||
| -rw-r--r-- | src/bin/migrate.rs | 9 | ||||
| -rw-r--r-- | src/http/mod.rs | 53 | ||||
| -rw-r--r-- | src/oauth/mod.rs | 4 | ||||
| -rw-r--r-- | src/oauth/pkce.rs | 57 | ||||
| -rw-r--r-- | src/oauth/server.rs | 240 | ||||
| -rw-r--r-- | src/oauth/types.rs | 2 |
8 files changed, 311 insertions, 94 deletions
@@ -11,6 +11,10 @@ path = "src/main.rs" name = "migrate" path = "src/bin/migrate.rs" +[[bin]] +name = "debug" +path = "src/bin/debug.rs" + [dependencies] serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/src/bin/debug.rs b/src/bin/debug.rs new file mode 100644 index 0000000..6d80848 --- /dev/null +++ b/src/bin/debug.rs @@ -0,0 +1,36 @@ +fn main() { + let config = sts::Config::from_env(); + println!("Config loaded: {}", config.bind_addr); + let server = sts::http::Server::new(config.clone()); + println!("Server result: {:?}", server.is_ok()); + + if let Ok(server) = server { + let oauth_server = &server.oauth_server; + let jwks = oauth_server.get_jwks(); + println!("JWKS length: {}", jwks.len()); + println!( + "JWKS: {}", + if jwks.len() > 100 { + &jwks[..100] + } else { + &jwks + } + ); + } + + let metadata = serde_json::json!({ + "issuer": config.issuer_url, + "authorization_endpoint": format!("{}/authorize", config.issuer_url), + "token_endpoint": format!("{}/token", config.issuer_url) + }); + let metadata_str = metadata.to_string(); + println!("Metadata length: {}", metadata_str.len()); + println!( + "Metadata: {}", + if metadata_str.len() > 100 { + &metadata_str[..100] + } else { + &metadata_str + } + ); +} diff --git a/src/bin/migrate.rs b/src/bin/migrate.rs index 9afdbf0..9a0bab9 100644 --- a/src/bin/migrate.rs +++ b/src/bin/migrate.rs @@ -1,11 +1,11 @@ use anyhow::Result; use rusqlite::Connection; -use sts::{Config, MigrationRunner}; use std::env; +use sts::{Config, MigrationRunner}; fn main() -> Result<()> { let args: Vec<String> = env::args().collect(); - + if args.len() < 2 { print_usage(); return Ok(()); @@ -29,7 +29,8 @@ fn main() -> Result<()> { eprintln!("Usage: cargo run --bin migrate rollback <version>"); return Ok(()); } - let version: i32 = args[2].parse() + let version: i32 = args[2] + .parse() .map_err(|_| anyhow::anyhow!("Invalid version number: {}", args[2]))?; runner.rollback_to_version(version)?; } @@ -58,4 +59,4 @@ fn print_usage() { println!(" cargo run --bin migrate up"); println!(" cargo run --bin migrate status"); println!(" cargo run --bin migrate rollback 0"); -}
\ No newline at end of file +} diff --git a/src/http/mod.rs b/src/http/mod.rs index c8d485b..1bc7951 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -8,13 +8,14 @@ use url::Url; pub struct Server { config: Config, - oauth_server: OAuthServer, + pub oauth_server: OAuthServer, } impl Server { pub fn new(config: Config) -> Result<Server, Box<dyn std::error::Error>> { Ok(Server { - oauth_server: OAuthServer::new(&config).map_err(|e| format!("Failed to create OAuth server: {}", e))?, + oauth_server: OAuthServer::new(&config) + .map_err(|e| format!("Failed to create OAuth server: {}", e))?, config, }) } @@ -69,7 +70,7 @@ impl Server { // Extract IP address for audit logging let ip_address = stream.peer_addr().ok().map(|addr| addr.ip().to_string()); - + match (method, path) { ("GET", "/") => self.serve_static_file(&mut stream, "./public/index.html"), ("GET", "/.well-known/oauth-authorization-server") => self.handle_metadata(&mut stream), @@ -92,13 +93,13 @@ impl Server { }; let response = format!( - "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n{}", + "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", content_type, contents.len(), contents ); let _ = stream.write_all(response.as_bytes()); - let _ = stream.flush(); + let _ = stream.flush(); } Err(_) => self.send_error_response(stream, 404, "Not Found"), } @@ -106,7 +107,7 @@ impl Server { fn send_error_response(&self, stream: &mut TcpStream, status: u16, message: &str) { let response = format!( - "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}", + "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", status, message, message.len(), @@ -123,38 +124,38 @@ impl Server { status_text: &str, json: &str, ) { - let security_headers = self.get_security_headers(); let response = format!( - "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n{}\r\n{}", + "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", status, status_text, json.len(), - security_headers, json ); let _ = stream.write_all(response.as_bytes()); let _ = stream.flush(); } - + fn send_empty_response(&self, stream: &mut TcpStream, status: u16, status_text: &str) { let security_headers = self.get_security_headers(); let response = format!( - "HTTP/1.1 {} {}\r\nContent-Length: 0\r\n{}\r\n", - status, - status_text, - security_headers + "HTTP/1.1 {} {}\r\nContent-Length: 0\r\nConnection: close\r\n{}\r\n", + status, status_text, security_headers ); let _ = stream.write_all(response.as_bytes()); let _ = stream.flush(); } - + fn get_security_headers(&self) -> String { let cors_origin = if self.config.cors_allowed_origins.contains(&"*".to_string()) { "*".to_string() } else { - self.config.cors_allowed_origins.first().unwrap_or(&"*".to_string()).clone() + self.config + .cors_allowed_origins + .first() + .unwrap_or(&"*".to_string()) + .clone() }; - + format!( "Access-Control-Allow-Origin: {}\r\n\ Access-Control-Allow-Methods: GET, POST, OPTIONS\r\n\ @@ -197,17 +198,21 @@ impl Server { self.send_json_response(stream, 200, "OK", &jwks); } - fn handle_authorize(&self, stream: &mut TcpStream, params: &HashMap<String, String>, ip_address: Option<String>) { + fn handle_authorize( + &self, + stream: &mut TcpStream, + params: &HashMap<String, String>, + ip_address: Option<String>, + ) { match self.oauth_server.handle_authorize(params, ip_address) { Ok(redirect_url) => { let security_headers = self.get_security_headers(); let response = format!( - "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\n{}\r\n", - redirect_url, - security_headers + "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\nConnection: close\r\n{}\r\n", + redirect_url, security_headers ); let _ = stream.write_all(response.as_bytes()); - let _ = stream.flush(); + let _ = stream.flush(); } Err(error_response) => { self.send_json_response(stream, 400, "Bad Request", &error_response); @@ -234,7 +239,7 @@ impl Server { } } } - + fn handle_introspect(&self, stream: &mut TcpStream, request: &str) { let body = self.extract_body(request); let form_params = self.parse_form_data(&body); @@ -252,7 +257,7 @@ impl Server { } } } - + fn handle_revoke(&self, stream: &mut TcpStream, request: &str) { let body = self.extract_body(request); let form_params = self.parse_form_data(&body); diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs index 7fd0d7b..4b18bb3 100644 --- a/src/oauth/mod.rs +++ b/src/oauth/mod.rs @@ -2,6 +2,8 @@ pub mod pkce; pub mod server; pub mod types; -pub use pkce::{CodeChallengeMethod, verify_code_challenge, generate_code_verifier, generate_code_challenge}; +pub use pkce::{ + generate_code_challenge, generate_code_verifier, verify_code_challenge, CodeChallengeMethod, +}; pub use server::OAuthServer; pub use types::{AuthCode, Claims, ErrorResponse, TokenResponse}; diff --git a/src/oauth/pkce.rs b/src/oauth/pkce.rs index c943844..406d364 100644 --- a/src/oauth/pkce.rs +++ b/src/oauth/pkce.rs @@ -1,6 +1,6 @@ -use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; -use sha2::{Digest, Sha256}; use anyhow::{anyhow, Result}; +use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; +use sha2::{Digest, Sha256}; #[derive(Debug, Clone, PartialEq)] pub enum CodeChallengeMethod { @@ -32,13 +32,16 @@ pub fn verify_code_challenge( ) -> Result<bool> { // Validate code verifier format (RFC 7636 Section 4.1) if code_verifier.len() < 43 || code_verifier.len() > 128 { - return Err(anyhow!("Code verifier length must be between 43 and 128 characters")); + return Err(anyhow!( + "Code verifier length must be between 43 and 128 characters" + )); } // Code verifier must only contain unreserved characters - if !code_verifier.chars().all(|c| { - c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~' - }) { + if !code_verifier + .chars() + .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~') + { return Err(anyhow!("Code verifier contains invalid characters")); } @@ -57,7 +60,7 @@ pub fn verify_code_challenge( pub fn generate_code_verifier() -> String { use rand::Rng; let mut rng = rand::thread_rng(); - + // Generate 32 random bytes and encode them let bytes: Vec<u8> = (0..32).map(|_| rng.r#gen()).collect(); URL_SAFE_NO_PAD.encode(&bytes) @@ -80,8 +83,14 @@ mod tests { #[test] fn test_code_challenge_method_from_str() { - assert_eq!(CodeChallengeMethod::from_str("plain").unwrap(), CodeChallengeMethod::Plain); - assert_eq!(CodeChallengeMethod::from_str("S256").unwrap(), CodeChallengeMethod::S256); + assert_eq!( + CodeChallengeMethod::from_str("plain").unwrap(), + CodeChallengeMethod::Plain + ); + assert_eq!( + CodeChallengeMethod::from_str("S256").unwrap(), + CodeChallengeMethod::S256 + ); assert!(CodeChallengeMethod::from_str("invalid").is_err()); } @@ -95,7 +104,7 @@ mod tests { fn test_verify_code_challenge_plain() { let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; let challenge = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; - + assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::Plain).unwrap()); assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::Plain).unwrap()); } @@ -104,7 +113,7 @@ mod tests { fn test_verify_code_challenge_s256() { let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; let challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; - + assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::S256).unwrap()); assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::S256).unwrap()); } @@ -113,9 +122,14 @@ mod tests { fn test_verify_code_challenge_invalid_verifier() { // Too short assert!(verify_code_challenge("short", "challenge", &CodeChallengeMethod::Plain).is_err()); - + // Invalid characters - assert!(verify_code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", "challenge", &CodeChallengeMethod::Plain).is_err()); + assert!(verify_code_challenge( + "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", + "challenge", + &CodeChallengeMethod::Plain + ) + .is_err()); } #[test] @@ -123,7 +137,7 @@ mod tests { let verifier = generate_code_verifier(); assert!(verifier.len() >= 43); assert!(verifier.len() <= 128); - + // Should only contain valid characters assert!(verifier.chars().all(|c| { c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~' @@ -133,24 +147,27 @@ mod tests { #[test] fn test_generate_code_challenge() { let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; - + let plain_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::Plain); assert_eq!(plain_challenge, verifier); - + let s256_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::S256); - assert_eq!(s256_challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"); + assert_eq!( + s256_challenge, + "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + ); } #[test] fn test_round_trip() { let verifier = generate_code_verifier(); - + // Test with S256 let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::S256); assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::S256).unwrap()); - + // Test with Plain let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::Plain); assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::Plain).unwrap()); } -}
\ No newline at end of file +} diff --git a/src/oauth/server.rs b/src/oauth/server.rs index 7552f00..7fd8b9c 100644 --- a/src/oauth/server.rs +++ b/src/oauth/server.rs @@ -1,9 +1,9 @@ use crate::clients::{parse_basic_auth, ClientManager}; use crate::config::Config; -use crate::database::{Database, DbAuthCode, DbAccessToken, DbAuditLog}; +use crate::database::{Database, DbAccessToken, DbAuditLog, DbAuthCode}; use crate::keys::KeyManager; -use crate::oauth::pkce::{CodeChallengeMethod, verify_code_challenge}; -use crate::oauth::types::{Claims, ErrorResponse, TokenResponse, TokenIntrospectionResponse}; +use crate::oauth::pkce::{verify_code_challenge, CodeChallengeMethod}; +use crate::oauth::types::{Claims, ErrorResponse, TokenIntrospectionResponse, TokenResponse}; use anyhow::{anyhow, Result}; use chrono::{Duration, Utc}; use jsonwebtoken::{encode, Algorithm, Header}; @@ -43,7 +43,11 @@ impl OAuthServer { } } - pub fn handle_authorize(&self, params: &HashMap<String, String>, ip_address: Option<String>) -> Result<String, String> { + pub fn handle_authorize( + &self, + params: &HashMap<String, String>, + ip_address: Option<String>, + ) -> Result<String, String> { let client_id = params .get("client_id") .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; @@ -58,7 +62,14 @@ impl OAuthServer { // Rate limiting check if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/authorize") { - self.audit_log("authorize_rate_limited", Some(client_id), None, ip_address.as_deref(), false, Some(&e.to_string())); + self.audit_log( + "authorize_rate_limited", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(&e.to_string()), + ); return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); } @@ -68,7 +79,14 @@ impl OAuthServer { match client_manager.get_client_from_db(client_id) { Ok(Some(client)) => client, Ok(None) => { - self.audit_log("authorize_invalid_client", Some(client_id), None, ip_address.as_deref(), false, None); + self.audit_log( + "authorize_invalid_client", + Some(client_id), + None, + ip_address.as_deref(), + false, + None, + ); return Err(self.error_response("invalid_client", "Invalid client_id")); } Err(_) => { @@ -81,7 +99,14 @@ impl OAuthServer { { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.is_redirect_uri_valid(client_id, redirect_uri) { - self.audit_log("authorize_invalid_redirect_uri", Some(client_id), None, ip_address.as_deref(), false, Some(redirect_uri)); + self.audit_log( + "authorize_invalid_redirect_uri", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(redirect_uri), + ); return Err(self.error_response("invalid_request", "Invalid redirect_uri")); } } @@ -91,13 +116,27 @@ impl OAuthServer { { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.is_scope_valid(client_id, &scope) { - self.audit_log("authorize_invalid_scope", Some(client_id), None, ip_address.as_deref(), false, scope.as_deref()); + self.audit_log( + "authorize_invalid_scope", + Some(client_id), + None, + ip_address.as_deref(), + false, + scope.as_deref(), + ); return Err(self.error_response("invalid_scope", "Invalid scope")); } } if response_type != "code" { - self.audit_log("authorize_unsupported_response_type", Some(client_id), None, ip_address.as_deref(), false, Some(response_type)); + self.audit_log( + "authorize_unsupported_response_type", + Some(client_id), + None, + ip_address.as_deref(), + false, + Some(response_type), + ); return Err(self.error_response( "unsupported_response_type", "Only code response type supported", @@ -106,14 +145,22 @@ impl OAuthServer { // PKCE validation (RFC 7636) let code_challenge = params.get("code_challenge"); - let code_challenge_method = params.get("code_challenge_method") + let code_challenge_method = params + .get("code_challenge_method") .map(|method| CodeChallengeMethod::from_str(method)) .transpose() .map_err(|_| self.error_response("invalid_request", "Invalid code_challenge_method"))?; // For public clients, PKCE is required if client.client_id.starts_with("public_") && code_challenge.is_none() { - self.audit_log("authorize_missing_pkce", Some(client_id), None, ip_address.as_deref(), false, None); + self.audit_log( + "authorize_missing_pkce", + Some(client_id), + None, + ip_address.as_deref(), + false, + None, + ); return Err(self.error_response("invalid_request", "PKCE required for public clients")); } @@ -131,14 +178,18 @@ impl OAuthServer { created_at: Utc::now(), is_used: false, code_challenge: code_challenge.cloned(), - code_challenge_method: code_challenge_method.as_ref().map(|m| m.as_str().to_string()), + code_challenge_method: code_challenge_method + .as_ref() + .map(|m| m.as_str().to_string()), }; // Save to database { let db = self.database.lock().unwrap(); if let Err(_) = db.create_auth_code(&db_auth_code) { - return Err(self.error_response("server_error", "Failed to create authorization code")); + return Err( + self.error_response("server_error", "Failed to create authorization code") + ); } } @@ -151,7 +202,14 @@ impl OAuthServer { redirect_url.query_pairs_mut().append_pair("state", state); } - self.audit_log("authorize_success", Some(client_id), Some("test_user"), ip_address.as_deref(), true, None); + self.audit_log( + "authorize_success", + Some(client_id), + Some("test_user"), + ip_address.as_deref(), + true, + None, + ); Ok(redirect_url.to_string()) } @@ -167,14 +225,20 @@ impl OAuthServer { .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; match grant_type.as_str() { - "authorization_code" => self.handle_authorization_code_grant(params, auth_header, ip_address), + "authorization_code" => { + self.handle_authorization_code_grant(params, auth_header, ip_address) + } "refresh_token" => self.handle_refresh_token_grant(params, auth_header, ip_address), _ => { - self.audit_log("token_unsupported_grant_type", None, None, ip_address.as_deref(), false, Some(grant_type)); - Err(self.error_response( - "unsupported_grant_type", - "Unsupported grant type", - )) + self.audit_log( + "token_unsupported_grant_type", + None, + None, + ip_address.as_deref(), + false, + Some(grant_type), + ); + Err(self.error_response("unsupported_grant_type", "Unsupported grant type")) } } } @@ -208,7 +272,14 @@ impl OAuthServer { // Rate limiting check if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/token") { - self.audit_log("token_rate_limited", Some(&client_id), None, ip_address.as_deref(), false, Some(&e.to_string())); + self.audit_log( + "token_rate_limited", + Some(&client_id), + None, + ip_address.as_deref(), + false, + Some(&e.to_string()), + ); return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); } @@ -216,7 +287,14 @@ impl OAuthServer { { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.authenticate_client(&client_id, &client_secret) { - self.audit_log("token_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None); + self.audit_log( + "token_invalid_client", + Some(&client_id), + None, + ip_address.as_deref(), + false, + None, + ); return Err(self.error_response("invalid_client", "Client authentication failed")); } } @@ -227,8 +305,16 @@ impl OAuthServer { match db.get_auth_code(code) { Ok(Some(auth_code)) => auth_code, Ok(None) => { - self.audit_log("token_invalid_code", Some(&client_id), None, ip_address.as_deref(), false, Some(code)); - return Err(self.error_response("invalid_grant", "Invalid or expired authorization code")); + self.audit_log( + "token_invalid_code", + Some(&client_id), + None, + ip_address.as_deref(), + false, + Some(code), + ); + return Err(self + .error_response("invalid_grant", "Invalid or expired authorization code")); } Err(_) => { return Err(self.error_response("server_error", "Internal server error")); @@ -238,17 +324,38 @@ impl OAuthServer { // Validate code hasn't been used and hasn't expired if auth_code.is_used { - self.audit_log("token_code_reuse", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code)); + self.audit_log( + "token_code_reuse", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + Some(code), + ); return Err(self.error_response("invalid_grant", "Authorization code already used")); } if Utc::now() > auth_code.expires_at { - self.audit_log("token_code_expired", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, Some(code)); + self.audit_log( + "token_code_expired", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + Some(code), + ); return Err(self.error_response("invalid_grant", "Authorization code expired")); } if auth_code.client_id != client_id { - self.audit_log("token_client_mismatch", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None); + self.audit_log( + "token_client_mismatch", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + None, + ); return Err(self.error_response("invalid_grant", "Client ID mismatch")); } @@ -258,13 +365,22 @@ impl OAuthServer { self.error_response("invalid_request", "Missing code_verifier for PKCE") })?; - let challenge_method = auth_code.code_challenge_method + let challenge_method = auth_code + .code_challenge_method .as_ref() .and_then(|method| CodeChallengeMethod::from_str(method).ok()) .unwrap_or(CodeChallengeMethod::Plain); - if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method) { - self.audit_log("token_pkce_verification_failed", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), false, None); + if let Err(_) = verify_code_challenge(code_verifier, code_challenge, &challenge_method) + { + self.audit_log( + "token_pkce_verification_failed", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + None, + ); return Err(self.error_response("invalid_grant", "PKCE verification failed")); } } @@ -279,8 +395,14 @@ impl OAuthServer { // Generate tokens let token_id = Uuid::new_v4().to_string(); - let access_token = self.generate_access_token(&auth_code.user_id, &client_id, &auth_code.scope, &token_id)?; - let refresh_token = self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?; + let access_token = self.generate_access_token( + &auth_code.user_id, + &client_id, + &auth_code.scope, + &token_id, + )?; + let refresh_token = + self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?; // Store token in database for revocation/introspection let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); @@ -311,7 +433,14 @@ impl OAuthServer { scope: auth_code.scope, }; - self.audit_log("token_success", Some(&client_id), Some(&auth_code.user_id), ip_address.as_deref(), true, None); + self.audit_log( + "token_success", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + true, + None, + ); serde_json::to_string(&token_response) .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) @@ -346,7 +475,14 @@ impl OAuthServer { { let mut client_manager = self.client_manager.lock().unwrap(); if !client_manager.authenticate_client(&client_id, &client_secret) { - self.audit_log("refresh_invalid_client", Some(&client_id), None, ip_address.as_deref(), false, None); + self.audit_log( + "refresh_invalid_client", + Some(&client_id), + None, + ip_address.as_deref(), + false, + None, + ); return Err(self.error_response("invalid_client", "Client authentication failed")); } } @@ -354,7 +490,8 @@ impl OAuthServer { // Validate refresh token (implementation would verify token and get user info) // For now, return a simple refresh token response let new_token_id = Uuid::new_v4().to_string(); - let access_token = self.generate_access_token("test_user", &client_id, &None, &new_token_id)?; + let access_token = + self.generate_access_token("test_user", &client_id, &None, &new_token_id)?; let new_refresh_token = self.generate_refresh_token(&client_id, "test_user", &None)?; let token_response = TokenResponse { @@ -365,7 +502,14 @@ impl OAuthServer { scope: None, }; - self.audit_log("refresh_success", Some(&client_id), Some("test_user"), ip_address.as_deref(), true, None); + self.audit_log( + "refresh_success", + Some(&client_id), + Some("test_user"), + ip_address.as_deref(), + true, + None, + ); serde_json::to_string(&token_response) .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) @@ -543,15 +687,23 @@ impl OAuthServer { fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> { let db = self.database.lock().unwrap(); let count = db.increment_rate_limit(identifier, endpoint, 1)?; - + if count > self.config.rate_limit_requests_per_minute as i32 { return Err(anyhow!("Rate limit exceeded")); } - + Ok(()) } - fn audit_log(&self, event_type: &str, client_id: Option<&str>, user_id: Option<&str>, ip_address: Option<&str>, success: bool, details: Option<&str>) { + fn audit_log( + &self, + event_type: &str, + client_id: Option<&str>, + user_id: Option<&str>, + ip_address: Option<&str>, + success: bool, + details: Option<&str>, + ) { if !self.config.enable_audit_logging { return; } @@ -584,16 +736,16 @@ impl OAuthServer { // Cleanup expired data pub fn cleanup_expired_data(&self) -> Result<()> { let db = self.database.lock().unwrap(); - + // Cleanup expired authorization codes let _ = db.cleanup_expired_codes(); - + // Cleanup expired tokens let _ = db.cleanup_expired_tokens(); - + // Cleanup old audit logs (keep for 30 days) let _ = db.cleanup_old_audit_logs(30); - + Ok(()) } -}
\ No newline at end of file +} diff --git a/src/oauth/types.rs b/src/oauth/types.rs index 0f9be5c..4f2c363 100644 --- a/src/oauth/types.rs +++ b/src/oauth/types.rs @@ -1,5 +1,5 @@ -use serde::{Deserialize, Serialize}; use crate::oauth::pkce::CodeChallengeMethod; +use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] pub struct Claims { |
