diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-11 15:12:59 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-11 15:12:59 -0600 |
| commit | 4435ee26b79648e92d0f172e42f9e6629e955505 (patch) | |
| tree | 0720fd07c879a58672fcfcb2e45ed1161430f039 /src/http/mod.rs | |
| parent | 39c67cfc6c74bf4b26ba455f3adda1241aea35ea (diff) | |
chore: rustfmt and include Connection: header in responses
Diffstat (limited to 'src/http/mod.rs')
| -rw-r--r-- | src/http/mod.rs | 53 |
1 files changed, 29 insertions, 24 deletions
diff --git a/src/http/mod.rs b/src/http/mod.rs index c8d485b..1bc7951 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -8,13 +8,14 @@ use url::Url; pub struct Server { config: Config, - oauth_server: OAuthServer, + pub oauth_server: OAuthServer, } impl Server { pub fn new(config: Config) -> Result<Server, Box<dyn std::error::Error>> { Ok(Server { - oauth_server: OAuthServer::new(&config).map_err(|e| format!("Failed to create OAuth server: {}", e))?, + oauth_server: OAuthServer::new(&config) + .map_err(|e| format!("Failed to create OAuth server: {}", e))?, config, }) } @@ -69,7 +70,7 @@ impl Server { // Extract IP address for audit logging let ip_address = stream.peer_addr().ok().map(|addr| addr.ip().to_string()); - + match (method, path) { ("GET", "/") => self.serve_static_file(&mut stream, "./public/index.html"), ("GET", "/.well-known/oauth-authorization-server") => self.handle_metadata(&mut stream), @@ -92,13 +93,13 @@ impl Server { }; let response = format!( - "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n{}", + "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", content_type, contents.len(), contents ); let _ = stream.write_all(response.as_bytes()); - let _ = stream.flush(); + let _ = stream.flush(); } Err(_) => self.send_error_response(stream, 404, "Not Found"), } @@ -106,7 +107,7 @@ impl Server { fn send_error_response(&self, stream: &mut TcpStream, status: u16, message: &str) { let response = format!( - "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}", + "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", status, message, message.len(), @@ -123,38 +124,38 @@ impl Server { status_text: &str, json: &str, ) { - let security_headers = self.get_security_headers(); let response = format!( - "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n{}\r\n{}", + "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", status, status_text, json.len(), - security_headers, json ); let _ = stream.write_all(response.as_bytes()); let _ = stream.flush(); } - + fn send_empty_response(&self, stream: &mut TcpStream, status: u16, status_text: &str) { let security_headers = self.get_security_headers(); let response = format!( - "HTTP/1.1 {} {}\r\nContent-Length: 0\r\n{}\r\n", - status, - status_text, - security_headers + "HTTP/1.1 {} {}\r\nContent-Length: 0\r\nConnection: close\r\n{}\r\n", + status, status_text, security_headers ); let _ = stream.write_all(response.as_bytes()); let _ = stream.flush(); } - + fn get_security_headers(&self) -> String { let cors_origin = if self.config.cors_allowed_origins.contains(&"*".to_string()) { "*".to_string() } else { - self.config.cors_allowed_origins.first().unwrap_or(&"*".to_string()).clone() + self.config + .cors_allowed_origins + .first() + .unwrap_or(&"*".to_string()) + .clone() }; - + format!( "Access-Control-Allow-Origin: {}\r\n\ Access-Control-Allow-Methods: GET, POST, OPTIONS\r\n\ @@ -197,17 +198,21 @@ impl Server { self.send_json_response(stream, 200, "OK", &jwks); } - fn handle_authorize(&self, stream: &mut TcpStream, params: &HashMap<String, String>, ip_address: Option<String>) { + fn handle_authorize( + &self, + stream: &mut TcpStream, + params: &HashMap<String, String>, + ip_address: Option<String>, + ) { match self.oauth_server.handle_authorize(params, ip_address) { Ok(redirect_url) => { let security_headers = self.get_security_headers(); let response = format!( - "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\n{}\r\n", - redirect_url, - security_headers + "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\nConnection: close\r\n{}\r\n", + redirect_url, security_headers ); let _ = stream.write_all(response.as_bytes()); - let _ = stream.flush(); + let _ = stream.flush(); } Err(error_response) => { self.send_json_response(stream, 400, "Bad Request", &error_response); @@ -234,7 +239,7 @@ impl Server { } } } - + fn handle_introspect(&self, stream: &mut TcpStream, request: &str) { let body = self.extract_body(request); let form_params = self.parse_form_data(&body); @@ -252,7 +257,7 @@ impl Server { } } } - + fn handle_revoke(&self, stream: &mut TcpStream, request: &str) { let body = self.extract_body(request); let form_params = self.parse_form_data(&body); |
