diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-11 12:51:49 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-11 12:51:49 -0600 |
| commit | 77c185a8db0d54cb66b28b694b1671428b831595 (patch) | |
| tree | 9e671ff4a22608955370656e85eb5991b4d85d22 /src/http | |
| parent | 7c41dfe19aa0ced3b895979ca4e369067fd58da1 (diff) | |
Add full implementation
Diffstat (limited to 'src/http')
| -rw-r--r-- | src/http/mod.rs | 114 |
1 files changed, 104 insertions, 10 deletions
diff --git a/src/http/mod.rs b/src/http/mod.rs index 6ab840d..c8d485b 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -14,7 +14,7 @@ pub struct Server { impl Server { pub fn new(config: Config) -> Result<Server, Box<dyn std::error::Error>> { Ok(Server { - oauth_server: OAuthServer::new(&config)?, + oauth_server: OAuthServer::new(&config).map_err(|e| format!("Failed to create OAuth server: {}", e))?, config, }) } @@ -67,12 +67,17 @@ impl Server { .map(|(k, v)| (k.to_string(), v.to_string())) .collect(); + // Extract IP address for audit logging + let ip_address = stream.peer_addr().ok().map(|addr| addr.ip().to_string()); + match (method, path) { ("GET", "/") => self.serve_static_file(&mut stream, "./public/index.html"), ("GET", "/.well-known/oauth-authorization-server") => self.handle_metadata(&mut stream), ("GET", "/jwks") => self.handle_jwks(&mut stream), - ("GET", "/authorize") => self.handle_authorize(&mut stream, &query_params), - ("POST", "/token") => self.handle_token(&mut stream, &request), + ("GET", "/authorize") => self.handle_authorize(&mut stream, &query_params, ip_address), + ("POST", "/token") => self.handle_token(&mut stream, &request, ip_address), + ("POST", "/introspect") => self.handle_introspect(&mut stream, &request), + ("POST", "/revoke") => self.handle_revoke(&mut stream, &request), _ => self.send_error_response(&mut stream, 404, "Not Found"), } } @@ -93,6 +98,7 @@ impl Server { contents ); let _ = stream.write_all(response.as_bytes()); + let _ = stream.flush(); } Err(_) => self.send_error_response(stream, 404, "Not Found"), } @@ -107,6 +113,7 @@ impl Server { message ); let _ = stream.write_all(response.as_bytes()); + let _ = stream.flush(); } fn send_json_response( @@ -116,14 +123,50 @@ impl Server { status_text: &str, json: &str, ) { + let security_headers = self.get_security_headers(); let response = format!( - "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n{}\r\n{}", status, status_text, json.len(), + security_headers, json ); let _ = stream.write_all(response.as_bytes()); + let _ = stream.flush(); + } + + fn send_empty_response(&self, stream: &mut TcpStream, status: u16, status_text: &str) { + let security_headers = self.get_security_headers(); + let response = format!( + "HTTP/1.1 {} {}\r\nContent-Length: 0\r\n{}\r\n", + status, + status_text, + security_headers + ); + let _ = stream.write_all(response.as_bytes()); + let _ = stream.flush(); + } + + fn get_security_headers(&self) -> String { + let cors_origin = if self.config.cors_allowed_origins.contains(&"*".to_string()) { + "*".to_string() + } else { + self.config.cors_allowed_origins.first().unwrap_or(&"*".to_string()).clone() + }; + + format!( + "Access-Control-Allow-Origin: {}\r\n\ + Access-Control-Allow-Methods: GET, POST, OPTIONS\r\n\ + Access-Control-Allow-Headers: Content-Type, Authorization\r\n\ + X-Content-Type-Options: nosniff\r\n\ + X-Frame-Options: DENY\r\n\ + X-XSS-Protection: 1; mode=block\r\n\ + Strict-Transport-Security: max-age=31536000; includeSubDomains\r\n\ + Content-Security-Policy: default-src 'self'; frame-ancestors 'none'\r\n\ + Referrer-Policy: strict-origin-when-cross-origin", + cors_origin + ) } fn handle_metadata(&self, stream: &mut TcpStream) { @@ -131,8 +174,20 @@ impl Server { "issuer": self.config.issuer_url, "authorization_endpoint": format!("{}/authorize", self.config.issuer_url), "token_endpoint": format!("{}/token", self.config.issuer_url), + "jwks_uri": format!("{}/jwks", self.config.issuer_url), + "introspection_endpoint": format!("{}/introspect", self.config.issuer_url), + "revocation_endpoint": format!("{}/revoke", self.config.issuer_url), "scopes_supported": ["openid", "profile", "email"], "response_types_supported": ["code"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"], + "code_challenge_methods_supported": ["plain", "S256"], + "response_modes_supported": ["query"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + "claims_supported": ["sub", "iss", "aud", "exp", "iat", "scope"], + "introspection_endpoint_auth_methods_supported": ["client_secret_basic"], + "revocation_endpoint_auth_methods_supported": ["client_secret_basic"] }); self.send_json_response(stream, 200, "OK", &metadata.to_string()); } @@ -142,14 +197,17 @@ impl Server { self.send_json_response(stream, 200, "OK", &jwks); } - fn handle_authorize(&self, stream: &mut TcpStream, params: &HashMap<String, String>) { - match self.oauth_server.handle_authorize(params) { + fn handle_authorize(&self, stream: &mut TcpStream, params: &HashMap<String, String>, ip_address: Option<String>) { + match self.oauth_server.handle_authorize(params, ip_address) { Ok(redirect_url) => { + let security_headers = self.get_security_headers(); let response = format!( - "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\n\r\n", - redirect_url + "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\n{}\r\n", + redirect_url, + security_headers ); let _ = stream.write_all(response.as_bytes()); + let _ = stream.flush(); } Err(error_response) => { self.send_json_response(stream, 400, "Bad Request", &error_response); @@ -157,7 +215,7 @@ impl Server { } } - fn handle_token(&self, stream: &mut TcpStream, request: &str) { + fn handle_token(&self, stream: &mut TcpStream, request: &str, ip_address: Option<String>) { let body = self.extract_body(request); let form_params = self.parse_form_data(&body); @@ -166,7 +224,7 @@ impl Server { match self .oauth_server - .handle_token(&form_params, auth_header.as_deref()) + .handle_token(&form_params, auth_header.as_deref(), ip_address) { Ok(token_response) => { self.send_json_response(stream, 200, "OK", &token_response); @@ -176,6 +234,42 @@ impl Server { } } } + + fn handle_introspect(&self, stream: &mut TcpStream, request: &str) { + let body = self.extract_body(request); + let form_params = self.parse_form_data(&body); + let auth_header = self.extract_auth_header(request); + + match self + .oauth_server + .handle_token_introspection(&form_params, auth_header.as_deref()) + { + Ok(introspection_response) => { + self.send_json_response(stream, 200, "OK", &introspection_response); + } + Err(error_response) => { + self.send_json_response(stream, 400, "Bad Request", &error_response); + } + } + } + + fn handle_revoke(&self, stream: &mut TcpStream, request: &str) { + let body = self.extract_body(request); + let form_params = self.parse_form_data(&body); + let auth_header = self.extract_auth_header(request); + + match self + .oauth_server + .handle_token_revocation(&form_params, auth_header.as_deref()) + { + Ok(_) => { + self.send_empty_response(stream, 200, "OK"); + } + Err(error_response) => { + self.send_json_response(stream, 400, "Bad Request", &error_response); + } + } + } fn extract_body(&self, request: &str) -> String { if let Some(pos) = request.find("\r\n\r\n") { |
