summaryrefslogtreecommitdiff
path: root/src/http
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-09 16:43:16 -0600
committermo khan <mo@mokhan.ca>2025-06-09 16:43:16 -0600
commit2ef774d4c52b9fb0ae0d1717b7a3568b76bccf3d (patch)
treefde8c20a9333e68d7e798ec5936630375da2a1f9 /src/http
parentb39a50e3ec622294cc0b6f271f1996a89f1849d6 (diff)
refactor: split types into separate files
Diffstat (limited to 'src/http')
-rw-r--r--src/http/mod.rs203
1 files changed, 203 insertions, 0 deletions
diff --git a/src/http/mod.rs b/src/http/mod.rs
new file mode 100644
index 0000000..a133f09
--- /dev/null
+++ b/src/http/mod.rs
@@ -0,0 +1,203 @@
+use crate::config::Config;
+use crate::oauth::OAuthServer;
+use std::collections::HashMap;
+use std::fs;
+use std::io::prelude::*;
+use std::net::{TcpListener, TcpStream};
+use url::Url;
+
+pub struct Server {
+ config: Config,
+ oauth_server: OAuthServer,
+}
+
+impl Server {
+ pub fn new(addr: String) -> Server {
+ let mut config = Config::from_env();
+ config.bind_addr = addr;
+ config.issuer_url = format!("http://{}", config.bind_addr);
+
+ Server {
+ oauth_server: OAuthServer::new(&config),
+ config,
+ }
+ }
+
+ pub fn start(&self) {
+ let listener = TcpListener::bind(self.config.bind_addr.clone()).unwrap();
+ println!("Listening on {}", self.config.bind_addr);
+
+ for stream in listener.incoming() {
+ match stream {
+ Ok(stream) => self.handle(stream),
+ Err(e) => eprintln!("Error accepting connection: {}", e),
+ }
+ }
+ }
+
+ pub fn handle(&self, mut stream: TcpStream) {
+ let mut buffer = [0; 8192];
+ let bytes_read = stream.read(&mut buffer).unwrap_or(0);
+ let request = String::from_utf8_lossy(&buffer[..bytes_read]);
+
+ let lines: Vec<&str> = request.lines().collect();
+ if lines.is_empty() {
+ self.send_error_response(&mut stream, 400, "Bad Request");
+ return;
+ }
+
+ let request_line = lines[0];
+ let parts: Vec<&str> = request_line.split_whitespace().collect();
+
+ if parts.len() != 3 {
+ self.send_error_response(&mut stream, 400, "Bad Request");
+ return;
+ }
+
+ let method = parts[0];
+ let path_and_query = parts[1];
+
+ let url = match Url::parse(&format!("http://localhost{}", path_and_query)) {
+ Ok(url) => url,
+ Err(_) => {
+ self.send_error_response(&mut stream, 400, "Bad Request");
+ return;
+ }
+ };
+
+ let path = url.path();
+ let query_params: HashMap<String, String> = url
+ .query_pairs()
+ .map(|(k, v)| (k.to_string(), v.to_string()))
+ .collect();
+
+ match (method, path) {
+ ("GET", "/") => self.serve_static_file(&mut stream, "./public/index.html"),
+ ("GET", "/.well-known/oauth-authorization-server") => {
+ self.handle_metadata(&mut stream)
+ }
+ ("GET", "/jwks") => self.handle_jwks(&mut stream),
+ ("GET", "/authorize") => self.handle_authorize(&mut stream, &query_params),
+ ("POST", "/token") => self.handle_token(&mut stream, &request),
+ _ => self.send_error_response(&mut stream, 404, "Not Found"),
+ }
+ }
+
+ fn serve_static_file(&self, stream: &mut TcpStream, filename: &str) {
+ match fs::read_to_string(filename) {
+ Ok(contents) => {
+ let content_type = if filename.ends_with(".json") {
+ "application/json"
+ } else {
+ "text/html"
+ };
+
+ let response = format!(
+ "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n{}",
+ content_type,
+ contents.len(),
+ contents
+ );
+ let _ = stream.write_all(response.as_bytes());
+ }
+ Err(_) => self.send_error_response(stream, 404, "Not Found"),
+ }
+ }
+
+ fn send_error_response(&self, stream: &mut TcpStream, status: u16, message: &str) {
+ let response = format!(
+ "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}",
+ status,
+ message,
+ message.len(),
+ message
+ );
+ let _ = stream.write_all(response.as_bytes());
+ }
+
+ fn send_json_response(
+ &self,
+ stream: &mut TcpStream,
+ status: u16,
+ status_text: &str,
+ json: &str,
+ ) {
+ let response = format!(
+ "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
+ status,
+ status_text,
+ json.len(),
+ json
+ );
+ let _ = stream.write_all(response.as_bytes());
+ }
+
+ fn handle_metadata(&self, stream: &mut TcpStream) {
+ let metadata = serde_json::json!({
+ "issuer": self.config.issuer_url,
+ "authorization_endpoint": format!("{}/authorize", self.config.issuer_url),
+ "token_endpoint": format!("{}/token", self.config.issuer_url),
+ "scopes_supported": ["openid", "profile", "email"],
+ "response_types_supported": ["code"],
+ });
+ self.send_json_response(stream, 200, "OK", &metadata.to_string());
+ }
+
+ fn handle_jwks(&self, stream: &mut TcpStream) {
+ let jwks = self.oauth_server.get_jwks();
+ self.send_json_response(stream, 200, "OK", &jwks);
+ }
+
+ fn handle_authorize(&self, stream: &mut TcpStream, params: &HashMap<String, String>) {
+ match self.oauth_server.handle_authorize(params) {
+ Ok(redirect_url) => {
+ let response = format!(
+ "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\n\r\n",
+ redirect_url
+ );
+ let _ = stream.write_all(response.as_bytes());
+ }
+ Err(error_response) => {
+ self.send_json_response(stream, 400, "Bad Request", &error_response);
+ }
+ }
+ }
+
+ fn handle_token(&self, stream: &mut TcpStream, request: &str) {
+ let body = self.extract_body(request);
+ let form_params = self.parse_form_data(&body);
+
+ match self.oauth_server.handle_token(&form_params) {
+ Ok(token_response) => {
+ self.send_json_response(stream, 200, "OK", &token_response);
+ }
+ Err(error_response) => {
+ self.send_json_response(stream, 400, "Bad Request", &error_response);
+ }
+ }
+ }
+
+ fn extract_body(&self, request: &str) -> String {
+ if let Some(pos) = request.find("\r\n\r\n") {
+ request[pos + 4..].to_string()
+ } else {
+ String::new()
+ }
+ }
+
+ fn parse_form_data(&self, body: &str) -> HashMap<String, String> {
+ body.split('&')
+ .filter_map(|pair| {
+ let mut split = pair.splitn(2, '=');
+ if let (Some(key), Some(value)) = (split.next(), split.next()) {
+ Some((
+ urlencoding::decode(key).unwrap_or_default().to_string(),
+ urlencoding::decode(value).unwrap_or_default().to_string(),
+ ))
+ } else {
+ None
+ }
+ })
+ .collect()
+ }
+} \ No newline at end of file