diff options
Diffstat (limited to 'src/http/mod.rs')
| -rw-r--r-- | src/http/mod.rs | 203 |
1 files changed, 203 insertions, 0 deletions
diff --git a/src/http/mod.rs b/src/http/mod.rs new file mode 100644 index 0000000..a133f09 --- /dev/null +++ b/src/http/mod.rs @@ -0,0 +1,203 @@ +use crate::config::Config; +use crate::oauth::OAuthServer; +use std::collections::HashMap; +use std::fs; +use std::io::prelude::*; +use std::net::{TcpListener, TcpStream}; +use url::Url; + +pub struct Server { + config: Config, + oauth_server: OAuthServer, +} + +impl Server { + pub fn new(addr: String) -> Server { + let mut config = Config::from_env(); + config.bind_addr = addr; + config.issuer_url = format!("http://{}", config.bind_addr); + + Server { + oauth_server: OAuthServer::new(&config), + config, + } + } + + pub fn start(&self) { + let listener = TcpListener::bind(self.config.bind_addr.clone()).unwrap(); + println!("Listening on {}", self.config.bind_addr); + + for stream in listener.incoming() { + match stream { + Ok(stream) => self.handle(stream), + Err(e) => eprintln!("Error accepting connection: {}", e), + } + } + } + + pub fn handle(&self, mut stream: TcpStream) { + let mut buffer = [0; 8192]; + let bytes_read = stream.read(&mut buffer).unwrap_or(0); + let request = String::from_utf8_lossy(&buffer[..bytes_read]); + + let lines: Vec<&str> = request.lines().collect(); + if lines.is_empty() { + self.send_error_response(&mut stream, 400, "Bad Request"); + return; + } + + let request_line = lines[0]; + let parts: Vec<&str> = request_line.split_whitespace().collect(); + + if parts.len() != 3 { + self.send_error_response(&mut stream, 400, "Bad Request"); + return; + } + + let method = parts[0]; + let path_and_query = parts[1]; + + let url = match Url::parse(&format!("http://localhost{}", path_and_query)) { + Ok(url) => url, + Err(_) => { + self.send_error_response(&mut stream, 400, "Bad Request"); + return; + } + }; + + let path = url.path(); + let query_params: HashMap<String, String> = url + .query_pairs() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + + match (method, path) { + ("GET", "/") => self.serve_static_file(&mut stream, "./public/index.html"), + ("GET", "/.well-known/oauth-authorization-server") => { + self.handle_metadata(&mut stream) + } + ("GET", "/jwks") => self.handle_jwks(&mut stream), + ("GET", "/authorize") => self.handle_authorize(&mut stream, &query_params), + ("POST", "/token") => self.handle_token(&mut stream, &request), + _ => self.send_error_response(&mut stream, 404, "Not Found"), + } + } + + fn serve_static_file(&self, stream: &mut TcpStream, filename: &str) { + match fs::read_to_string(filename) { + Ok(contents) => { + let content_type = if filename.ends_with(".json") { + "application/json" + } else { + "text/html" + }; + + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n{}", + content_type, + contents.len(), + contents + ); + let _ = stream.write_all(response.as_bytes()); + } + Err(_) => self.send_error_response(stream, 404, "Not Found"), + } + } + + fn send_error_response(&self, stream: &mut TcpStream, status: u16, message: &str) { + let response = format!( + "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}", + status, + message, + message.len(), + message + ); + let _ = stream.write_all(response.as_bytes()); + } + + fn send_json_response( + &self, + stream: &mut TcpStream, + status: u16, + status_text: &str, + json: &str, + ) { + let response = format!( + "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + status, + status_text, + json.len(), + json + ); + let _ = stream.write_all(response.as_bytes()); + } + + fn handle_metadata(&self, stream: &mut TcpStream) { + let metadata = serde_json::json!({ + "issuer": self.config.issuer_url, + "authorization_endpoint": format!("{}/authorize", self.config.issuer_url), + "token_endpoint": format!("{}/token", self.config.issuer_url), + "scopes_supported": ["openid", "profile", "email"], + "response_types_supported": ["code"], + }); + self.send_json_response(stream, 200, "OK", &metadata.to_string()); + } + + fn handle_jwks(&self, stream: &mut TcpStream) { + let jwks = self.oauth_server.get_jwks(); + self.send_json_response(stream, 200, "OK", &jwks); + } + + fn handle_authorize(&self, stream: &mut TcpStream, params: &HashMap<String, String>) { + match self.oauth_server.handle_authorize(params) { + Ok(redirect_url) => { + let response = format!( + "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\n\r\n", + redirect_url + ); + let _ = stream.write_all(response.as_bytes()); + } + Err(error_response) => { + self.send_json_response(stream, 400, "Bad Request", &error_response); + } + } + } + + fn handle_token(&self, stream: &mut TcpStream, request: &str) { + let body = self.extract_body(request); + let form_params = self.parse_form_data(&body); + + match self.oauth_server.handle_token(&form_params) { + Ok(token_response) => { + self.send_json_response(stream, 200, "OK", &token_response); + } + Err(error_response) => { + self.send_json_response(stream, 400, "Bad Request", &error_response); + } + } + } + + fn extract_body(&self, request: &str) -> String { + if let Some(pos) = request.find("\r\n\r\n") { + request[pos + 4..].to_string() + } else { + String::new() + } + } + + fn parse_form_data(&self, body: &str) -> HashMap<String, String> { + body.split('&') + .filter_map(|pair| { + let mut split = pair.splitn(2, '='); + if let (Some(key), Some(value)) = (split.next(), split.next()) { + Some(( + urlencoding::decode(key).unwrap_or_default().to_string(), + urlencoding::decode(value).unwrap_or_default().to_string(), + )) + } else { + None + } + }) + .collect() + } +}
\ No newline at end of file |
