diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-09 16:43:16 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-09 16:43:16 -0600 |
| commit | 2ef774d4c52b9fb0ae0d1717b7a3568b76bccf3d (patch) | |
| tree | fde8c20a9333e68d7e798ec5936630375da2a1f9 /src/lib.rs | |
| parent | b39a50e3ec622294cc0b6f271f1996a89f1849d6 (diff) | |
refactor: split types into separate files
Diffstat (limited to 'src/lib.rs')
| -rw-r--r-- | src/lib.rs | 450 |
1 files changed, 6 insertions, 444 deletions
@@ -1,445 +1,7 @@ -use base64::prelude::*; -use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode}; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::fs; -use std::io::BufReader; -use std::io::prelude::*; -use std::net::{TcpListener, TcpStream}; -use std::time::{SystemTime, UNIX_EPOCH}; -use url::Url; -use uuid::Uuid; +pub mod config; +pub mod http; +pub mod oauth; -#[derive(Debug, Clone)] -pub struct Config { - pub bind_addr: String, - pub issuer_url: String, - pub jwt_secret: String, -} - -impl Config { - pub fn from_env() -> Self { - let bind_addr = std::env::var("BIND_ADDR").unwrap_or_else(|_| "127.0.0.1:7878".to_string()); - let issuer_url = format!("http://{}", bind_addr); - let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| { - "your-256-bit-secret-key-here-make-it-very-long-and-secure".to_string() - }); - - Self { - bind_addr, - issuer_url, - jwt_secret, - } - } -} - -pub mod http { - use super::*; - - pub struct Server { - config: Config, - oauth_server: OAuthServer, - } - - impl Server { - pub fn new(addr: String) -> Server { - let mut config = Config::from_env(); - config.bind_addr = addr; - config.issuer_url = format!("http://{}", config.bind_addr); - - Server { - oauth_server: OAuthServer::new(&config), - config, - } - } - - pub fn start(&self) { - let listener = TcpListener::bind(self.config.bind_addr.clone()).unwrap(); - println!("Listening on {}", self.config.bind_addr); - - for stream in listener.incoming() { - match stream { - Ok(stream) => self.handle(stream), - Err(e) => eprintln!("Error accepting connection: {}", e), - } - } - } - - pub fn handle(&self, mut stream: TcpStream) { - let mut buffer = [0; 8192]; - let bytes_read = stream.read(&mut buffer).unwrap_or(0); - let request = String::from_utf8_lossy(&buffer[..bytes_read]); - - let lines: Vec<&str> = request.lines().collect(); - if lines.is_empty() { - self.send_error_response(&mut stream, 400, "Bad Request"); - return; - } - - let request_line = lines[0]; - let parts: Vec<&str> = request_line.split_whitespace().collect(); - - if parts.len() != 3 { - self.send_error_response(&mut stream, 400, "Bad Request"); - return; - } - - let method = parts[0]; - let path_and_query = parts[1]; - - // Parse URL and query parameters - let url = match Url::parse(&format!("http://localhost{}", path_and_query)) { - Ok(url) => url, - Err(_) => { - self.send_error_response(&mut stream, 400, "Bad Request"); - return; - } - }; - - let path = url.path(); - let query_params: HashMap<String, String> = url - .query_pairs() - .map(|(k, v)| (k.to_string(), v.to_string())) - .collect(); - - match (method, path) { - ("GET", "/") => self.serve_static_file(&mut stream, "./public/index.html"), - ("GET", "/.well-known/oauth-authorization-server") => { - self.handle_metadata(&mut stream) - } - ("GET", "/jwks") => self.handle_jwks(&mut stream), - ("GET", "/authorize") => self.handle_authorize(&mut stream, &query_params), - ("POST", "/token") => self.handle_token(&mut stream, &request), - _ => self.send_error_response(&mut stream, 404, "Not Found"), - } - } - - fn serve_static_file(&self, stream: &mut TcpStream, filename: &str) { - match fs::read_to_string(filename) { - Ok(contents) => { - let content_type = if filename.ends_with(".json") { - "application/json" - } else { - "text/html" - }; - - let response = format!( - "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n{}", - content_type, - contents.len(), - contents - ); - let _ = stream.write_all(response.as_bytes()); - } - Err(_) => self.send_error_response(stream, 404, "Not Found"), - } - } - - fn send_error_response(&self, stream: &mut TcpStream, status: u16, message: &str) { - let response = format!( - "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}", - status, - message, - message.len(), - message - ); - let _ = stream.write_all(response.as_bytes()); - } - - fn send_json_response( - &self, - stream: &mut TcpStream, - status: u16, - status_text: &str, - json: &str, - ) { - let response = format!( - "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", - status, - status_text, - json.len(), - json - ); - let _ = stream.write_all(response.as_bytes()); - } - - fn handle_metadata(&self, stream: &mut TcpStream) { - let metadata = serde_json::json!({ - "issuer": self.config.issuer_url, - "authorization_endpoint": format!("{}/authorize", self.config.issuer_url), - "token_endpoint": format!("{}/token", self.config.issuer_url), - "scopes_supported": ["openid", "profile", "email"], - "response_types_supported": ["code"], - }); - self.send_json_response(stream, 200, "OK", &metadata.to_string()); - } - - fn handle_jwks(&self, stream: &mut TcpStream) { - let jwks = self.oauth_server.get_jwks(); - self.send_json_response(stream, 200, "OK", &jwks); - } - - fn handle_authorize(&self, stream: &mut TcpStream, params: &HashMap<String, String>) { - match self.oauth_server.handle_authorize(params) { - Ok(redirect_url) => { - let response = format!( - "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\n\r\n", - redirect_url - ); - let _ = stream.write_all(response.as_bytes()); - } - Err(error_response) => { - self.send_json_response(stream, 400, "Bad Request", &error_response); - } - } - } - - fn handle_token(&self, stream: &mut TcpStream, request: &str) { - let body = self.extract_body(request); - let form_params = self.parse_form_data(&body); - - match self.oauth_server.handle_token(&form_params) { - Ok(token_response) => { - self.send_json_response(stream, 200, "OK", &token_response); - } - Err(error_response) => { - self.send_json_response(stream, 400, "Bad Request", &error_response); - } - } - } - - fn extract_body(&self, request: &str) -> String { - if let Some(pos) = request.find("\r\n\r\n") { - request[pos + 4..].to_string() - } else { - String::new() - } - } - - fn parse_form_data(&self, body: &str) -> HashMap<String, String> { - body.split('&') - .filter_map(|pair| { - let mut split = pair.splitn(2, '='); - if let (Some(key), Some(value)) = (split.next(), split.next()) { - Some(( - urlencoding::decode(key).unwrap_or_default().to_string(), - urlencoding::decode(value).unwrap_or_default().to_string(), - )) - } else { - None - } - }) - .collect() - } - } -} - -#[derive(Debug, Serialize, Deserialize)] -struct Claims { - sub: String, - iss: String, - aud: String, - exp: u64, - iat: u64, - #[serde(skip_serializing_if = "Option::is_none")] - scope: Option<String>, -} - -#[derive(Debug, Serialize, Deserialize)] -struct TokenResponse { - access_token: String, - token_type: String, - expires_in: u64, - #[serde(skip_serializing_if = "Option::is_none")] - refresh_token: Option<String>, - #[serde(skip_serializing_if = "Option::is_none")] - scope: Option<String>, -} - -#[derive(Debug, Serialize, Deserialize)] -struct ErrorResponse { - error: String, - #[serde(skip_serializing_if = "Option::is_none")] - error_description: Option<String>, -} - -pub struct OAuthServer { - config: Config, - encoding_key: EncodingKey, - decoding_key: DecodingKey, - auth_codes: std::sync::Mutex<HashMap<String, AuthCode>>, -} - -#[derive(Debug, Clone)] -struct AuthCode { - client_id: String, - redirect_uri: String, - scope: Option<String>, - expires_at: u64, - user_id: String, -} - -impl OAuthServer { - pub fn new(config: &Config) -> Self { - Self { - encoding_key: EncodingKey::from_secret(config.jwt_secret.as_ref()), - decoding_key: DecodingKey::from_secret(config.jwt_secret.as_ref()), - auth_codes: std::sync::Mutex::new(HashMap::new()), - config: config.clone(), - } - } - - fn get_jwks(&self) -> String { - // For simplicity, returning empty JWKS. In production, include public key - serde_json::json!({ - "keys": [] - }) - .to_string() - } - - pub fn handle_authorize(&self, params: &HashMap<String, String>) -> Result<String, String> { - // Validate required parameters - let client_id = params - .get("client_id") - .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; - - let redirect_uri = params - .get("redirect_uri") - .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?; - - let response_type = params - .get("response_type") - .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?; - - if response_type != "code" { - return Err(self.error_response( - "unsupported_response_type", - "Only code response type supported", - )); - } - - // Generate authorization code - let code = Uuid::new_v4().to_string(); - let expires_at = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() - + 600; // 10 minutes - - let auth_code = AuthCode { - client_id: client_id.clone(), - redirect_uri: redirect_uri.clone(), - scope: params.get("scope").cloned(), - expires_at, - user_id: "test_user".to_string(), // In production, get from authentication - }; - - { - let mut codes = self.auth_codes.lock().unwrap(); - codes.insert(code.clone(), auth_code); - } - - // Build redirect URL with authorization code - let mut redirect_url = Url::parse(redirect_uri) - .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?; - - redirect_url.query_pairs_mut().append_pair("code", &code); - - if let Some(state) = params.get("state") { - redirect_url.query_pairs_mut().append_pair("state", state); - } - - Ok(redirect_url.to_string()) - } - - fn handle_token(&self, params: &HashMap<String, String>) -> Result<String, String> { - let grant_type = params - .get("grant_type") - .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; - - if grant_type != "authorization_code" { - return Err(self.error_response( - "unsupported_grant_type", - "Only authorization_code grant type supported", - )); - } - - let code = params - .get("code") - .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?; - - let client_id = params - .get("client_id") - .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; - - // Validate authorization code - let auth_code = { - let mut codes = self.auth_codes.lock().unwrap(); - codes.remove(code).ok_or_else(|| { - self.error_response("invalid_grant", "Invalid or expired authorization code") - })? - }; - - // Verify client_id matches - if auth_code.client_id != *client_id { - return Err(self.error_response("invalid_grant", "Client ID mismatch")); - } - - // Check expiration - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - - if now > auth_code.expires_at { - return Err(self.error_response("invalid_grant", "Authorization code expired")); - } - - // Generate access token - let access_token = - self.generate_access_token(&auth_code.user_id, client_id, &auth_code.scope)?; - - let token_response = TokenResponse { - access_token, - token_type: "Bearer".to_string(), - expires_in: 3600, // 1 hour - refresh_token: None, - scope: auth_code.scope, - }; - - serde_json::to_string(&token_response) - .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) - } - - fn generate_access_token( - &self, - user_id: &str, - client_id: &str, - scope: &Option<String>, - ) -> Result<String, String> { - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs(); - - let claims = Claims { - sub: user_id.to_string(), - iss: self.config.issuer_url.clone(), - aud: client_id.to_string(), - exp: now + 3600, // 1 hour - iat: now, - scope: scope.clone(), - }; - - encode(&Header::default(), &claims, &self.encoding_key) - .map_err(|_| self.error_response("server_error", "Failed to generate token")) - } - - fn error_response(&self, error: &str, description: &str) -> String { - let error_resp = ErrorResponse { - error: error.to_string(), - error_description: Some(description.to_string()), - }; - serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string()) - } -} +pub use config::Config; +pub use http::Server; +pub use oauth::OAuthServer; |
