use crate::config::Config; use crate::oauth::OAuthServer; use std::collections::HashMap; use std::fs; use std::io::prelude::*; use std::net::{TcpListener, TcpStream}; use url::Url; pub struct Server { config: Config, oauth_server: OAuthServer, } impl Server { pub fn new(config: Config) -> Result> { Ok(Server { oauth_server: OAuthServer::new(&config)?, config, }) } pub fn start(&self) { let listener = TcpListener::bind(self.config.bind_addr.clone()).unwrap(); println!("Listening on {}", self.config.bind_addr); for stream in listener.incoming() { match stream { Ok(stream) => self.handle(stream), Err(e) => eprintln!("Error accepting connection: {}", e), } } } pub fn handle(&self, mut stream: TcpStream) { let mut buffer = [0; 8192]; let bytes_read = stream.read(&mut buffer).unwrap_or(0); let request = String::from_utf8_lossy(&buffer[..bytes_read]); let lines: Vec<&str> = request.lines().collect(); if lines.is_empty() { self.send_error_response(&mut stream, 400, "Bad Request"); return; } let request_line = lines[0]; let parts: Vec<&str> = request_line.split_whitespace().collect(); if parts.len() != 3 { self.send_error_response(&mut stream, 400, "Bad Request"); return; } let method = parts[0]; let path_and_query = parts[1]; let url = match Url::parse(&format!("http://localhost{}", path_and_query)) { Ok(url) => url, Err(_) => { self.send_error_response(&mut stream, 400, "Bad Request"); return; } }; let path = url.path(); let query_params: HashMap = url .query_pairs() .map(|(k, v)| (k.to_string(), v.to_string())) .collect(); match (method, path) { ("GET", "/") => self.serve_static_file(&mut stream, "./public/index.html"), ("GET", "/.well-known/oauth-authorization-server") => self.handle_metadata(&mut stream), ("GET", "/jwks") => self.handle_jwks(&mut stream), ("GET", "/authorize") => self.handle_authorize(&mut stream, &query_params), ("POST", "/token") => self.handle_token(&mut stream, &request), _ => self.send_error_response(&mut stream, 404, "Not Found"), } } fn serve_static_file(&self, stream: &mut TcpStream, filename: &str) { match fs::read_to_string(filename) { Ok(contents) => { let content_type = if filename.ends_with(".json") { "application/json" } else { "text/html" }; let response = format!( "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n{}", content_type, contents.len(), contents ); let _ = stream.write_all(response.as_bytes()); } Err(_) => self.send_error_response(stream, 404, "Not Found"), } } fn send_error_response(&self, stream: &mut TcpStream, status: u16, message: &str) { let response = format!( "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}", status, message, message.len(), message ); let _ = stream.write_all(response.as_bytes()); } fn send_json_response( &self, stream: &mut TcpStream, status: u16, status_text: &str, json: &str, ) { let response = format!( "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", status, status_text, json.len(), json ); let _ = stream.write_all(response.as_bytes()); } fn handle_metadata(&self, stream: &mut TcpStream) { let metadata = serde_json::json!({ "issuer": self.config.issuer_url, "authorization_endpoint": format!("{}/authorize", self.config.issuer_url), "token_endpoint": format!("{}/token", self.config.issuer_url), "scopes_supported": ["openid", "profile", "email"], "response_types_supported": ["code"], }); self.send_json_response(stream, 200, "OK", &metadata.to_string()); } fn handle_jwks(&self, stream: &mut TcpStream) { let jwks = self.oauth_server.get_jwks(); self.send_json_response(stream, 200, "OK", &jwks); } fn handle_authorize(&self, stream: &mut TcpStream, params: &HashMap) { match self.oauth_server.handle_authorize(params) { Ok(redirect_url) => { let response = format!( "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\n\r\n", redirect_url ); let _ = stream.write_all(response.as_bytes()); } Err(error_response) => { self.send_json_response(stream, 400, "Bad Request", &error_response); } } } fn handle_token(&self, stream: &mut TcpStream, request: &str) { let body = self.extract_body(request); let form_params = self.parse_form_data(&body); // Extract Authorization header from request let auth_header = self.extract_auth_header(request); match self .oauth_server .handle_token(&form_params, auth_header.as_deref()) { Ok(token_response) => { self.send_json_response(stream, 200, "OK", &token_response); } Err(error_response) => { self.send_json_response(stream, 400, "Bad Request", &error_response); } } } fn extract_body(&self, request: &str) -> String { if let Some(pos) = request.find("\r\n\r\n") { request[pos + 4..].to_string() } else { String::new() } } fn parse_form_data(&self, body: &str) -> HashMap { body.split('&') .filter_map(|pair| { let mut split = pair.splitn(2, '='); if let (Some(key), Some(value)) = (split.next(), split.next()) { Some(( urlencoding::decode(key).unwrap_or_default().to_string(), urlencoding::decode(value).unwrap_or_default().to_string(), )) } else { None } }) .collect() } fn extract_auth_header(&self, request: &str) -> Option { let lines: Vec<&str> = request.lines().collect(); for line in lines.iter().skip(1) { // Skip the request line if line.to_lowercase().starts_with("authorization:") { return Some(line[14..].trim().to_string()); // Skip "Authorization: " } } None } }