diff options
Diffstat (limited to 'src/lib.rs')
| -rw-r--r-- | src/lib.rs | 451 |
1 files changed, 430 insertions, 21 deletions
@@ -1,42 +1,451 @@ +use base64::prelude::*; +use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fs; +use std::io::BufReader; +use std::io::prelude::*; +use std::net::{TcpListener, TcpStream}; +use std::time::{SystemTime, UNIX_EPOCH}; +use url::Url; +use uuid::Uuid; + +#[derive(Debug, Clone)] +pub struct Config { + pub bind_addr: String, + pub issuer_url: String, + pub jwt_secret: String, +} + +impl Config { + pub fn from_env() -> Self { + let bind_addr = std::env::var("BIND_ADDR").unwrap_or_else(|_| "127.0.0.1:7878".to_string()); + let issuer_url = format!("http://{}", bind_addr); + let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| { + "your-256-bit-secret-key-here-make-it-very-long-and-secure".to_string() + }); + + Self { + bind_addr, + issuer_url, + jwt_secret, + } + } +} + pub mod http { - use std::fs; - use std::io::BufReader; - use std::io::prelude::*; - use std::net::TcpListener; - use std::net::TcpStream; + use super::*; + pub struct Server { - addr: String, + config: Config, + oauth_server: OAuthServer, } impl Server { pub fn new(addr: String) -> Server { - Server { addr } + let mut config = Config::from_env(); + config.bind_addr = addr; + config.issuer_url = format!("http://{}", config.bind_addr); + + Server { + oauth_server: OAuthServer::new(&config), + config, + } } pub fn start(&self) { - let listener = TcpListener::bind(self.addr.clone()).unwrap(); - for next_stream in listener.incoming() { - self.handle(next_stream.unwrap()); + let listener = TcpListener::bind(self.config.bind_addr.clone()).unwrap(); + println!("OAuth2 STS Server listening on {}", self.config.bind_addr); + + for stream in listener.incoming() { + match stream { + Ok(stream) => self.handle(stream), + Err(e) => eprintln!("Error accepting connection: {}", e), + } } } pub fn handle(&self, mut stream: TcpStream) { - let io = BufReader::new(&stream); - let request_line = io.lines().next().unwrap().unwrap(); + let mut buffer = [0; 8192]; + let bytes_read = stream.read(&mut buffer).unwrap_or(0); + let request = String::from_utf8_lossy(&buffer[..bytes_read]); - let (status_line, filename) = match &request_line[..] { - "GET / HTTP/1.1" => ("HTTP/1.1 200 OK", "./public/index.html"), - "GET /.well-known/oauth-authorization-server HTTP/1.1" => { - ("HTTP/1.1 200 OK", "./public/metadata.json") + let lines: Vec<&str> = request.lines().collect(); + if lines.is_empty() { + self.send_error_response(&mut stream, 400, "Bad Request"); + return; + } + + let request_line = lines[0]; + let parts: Vec<&str> = request_line.split_whitespace().collect(); + + if parts.len() != 3 { + self.send_error_response(&mut stream, 400, "Bad Request"); + return; + } + + let method = parts[0]; + let path_and_query = parts[1]; + + // Parse URL and query parameters + let url = match Url::parse(&format!("http://localhost{}", path_and_query)) { + Ok(url) => url, + Err(_) => { + self.send_error_response(&mut stream, 400, "Bad Request"); + return; } - _ => ("HTTP/1.1 404 NOT FOUND", "./public/404.html"), }; - let contents = fs::read_to_string(filename).unwrap(); - let length = contents.len(); - let response = format!("{status_line}\r\nContent-Length: {length}\r\n\r\n{contents}"); + let path = url.path(); + let query_params: HashMap<String, String> = url + .query_pairs() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + + match (method, path) { + ("GET", "/") => self.serve_static_file(&mut stream, "./public/index.html"), + ("GET", "/.well-known/oauth-authorization-server") => { + self.handle_metadata(&mut stream) + } + ("GET", "/jwks") => self.handle_jwks(&mut stream), + ("GET", "/authorize") => self.handle_authorize(&mut stream, &query_params), + ("POST", "/token") => self.handle_token(&mut stream, &request), + _ => self.send_error_response(&mut stream, 404, "Not Found"), + } + } + + fn serve_static_file(&self, stream: &mut TcpStream, filename: &str) { + match fs::read_to_string(filename) { + Ok(contents) => { + let content_type = if filename.ends_with(".json") { + "application/json" + } else { + "text/html" + }; + + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n{}", + content_type, + contents.len(), + contents + ); + let _ = stream.write_all(response.as_bytes()); + } + Err(_) => self.send_error_response(stream, 404, "Not Found"), + } + } + + fn send_error_response(&self, stream: &mut TcpStream, status: u16, message: &str) { + let response = format!( + "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}", + status, + message, + message.len(), + message + ); + let _ = stream.write_all(response.as_bytes()); + } + + fn send_json_response( + &self, + stream: &mut TcpStream, + status: u16, + status_text: &str, + json: &str, + ) { + let response = format!( + "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + status, + status_text, + json.len(), + json + ); + let _ = stream.write_all(response.as_bytes()); + } + + fn handle_metadata(&self, stream: &mut TcpStream) { + let metadata = serde_json::json!({ + "issuer": self.config.issuer_url, + "authorization_endpoint": format!("{}/authorize", self.config.issuer_url), + "token_endpoint": format!("{}/token", self.config.issuer_url), + "jwks_uri": format!("{}/jwks", self.config.issuer_url), + "scopes_supported": ["openid", "profile", "email"], + "response_types_supported": ["code"], + "response_modes_supported": ["query"], + "grant_types_supported": ["authorization_code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"] + }); + self.send_json_response(stream, 200, "OK", &metadata.to_string()); + } + + fn handle_jwks(&self, stream: &mut TcpStream) { + let jwks = self.oauth_server.get_jwks(); + self.send_json_response(stream, 200, "OK", &jwks); + } + + fn handle_authorize(&self, stream: &mut TcpStream, params: &HashMap<String, String>) { + match self.oauth_server.handle_authorize(params) { + Ok(redirect_url) => { + let response = format!( + "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\n\r\n", + redirect_url + ); + let _ = stream.write_all(response.as_bytes()); + } + Err(error_response) => { + self.send_json_response(stream, 400, "Bad Request", &error_response); + } + } + } + + fn handle_token(&self, stream: &mut TcpStream, request: &str) { + let body = self.extract_body(request); + let form_params = self.parse_form_data(&body); + + match self.oauth_server.handle_token(&form_params) { + Ok(token_response) => { + self.send_json_response(stream, 200, "OK", &token_response); + } + Err(error_response) => { + self.send_json_response(stream, 400, "Bad Request", &error_response); + } + } + } + + fn extract_body(&self, request: &str) -> String { + if let Some(pos) = request.find("\r\n\r\n") { + request[pos + 4..].to_string() + } else { + String::new() + } + } + + fn parse_form_data(&self, body: &str) -> HashMap<String, String> { + body.split('&') + .filter_map(|pair| { + let mut split = pair.splitn(2, '='); + if let (Some(key), Some(value)) = (split.next(), split.next()) { + Some(( + urlencoding::decode(key).unwrap_or_default().to_string(), + urlencoding::decode(value).unwrap_or_default().to_string(), + )) + } else { + None + } + }) + .collect() + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct Claims { + sub: String, + iss: String, + aud: String, + exp: u64, + iat: u64, + #[serde(skip_serializing_if = "Option::is_none")] + scope: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +struct TokenResponse { + access_token: String, + token_type: String, + expires_in: u64, + #[serde(skip_serializing_if = "Option::is_none")] + refresh_token: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + scope: Option<String>, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ErrorResponse { + error: String, + #[serde(skip_serializing_if = "Option::is_none")] + error_description: Option<String>, +} + +pub struct OAuthServer { + config: Config, + encoding_key: EncodingKey, + decoding_key: DecodingKey, + auth_codes: std::sync::Mutex<HashMap<String, AuthCode>>, +} + +#[derive(Debug, Clone)] +struct AuthCode { + client_id: String, + redirect_uri: String, + scope: Option<String>, + expires_at: u64, + user_id: String, +} - stream.write_all(response.as_bytes()).unwrap(); +impl OAuthServer { + pub fn new(config: &Config) -> Self { + Self { + encoding_key: EncodingKey::from_secret(config.jwt_secret.as_ref()), + decoding_key: DecodingKey::from_secret(config.jwt_secret.as_ref()), + auth_codes: std::sync::Mutex::new(HashMap::new()), + config: config.clone(), } } + + fn get_jwks(&self) -> String { + // For simplicity, returning empty JWKS. In production, include public key + serde_json::json!({ + "keys": [] + }) + .to_string() + } + + pub fn handle_authorize(&self, params: &HashMap<String, String>) -> Result<String, String> { + // Validate required parameters + let client_id = params + .get("client_id") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + + let redirect_uri = params + .get("redirect_uri") + .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?; + + let response_type = params + .get("response_type") + .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?; + + if response_type != "code" { + return Err(self.error_response( + "unsupported_response_type", + "Only code response type supported", + )); + } + + // Generate authorization code + let code = Uuid::new_v4().to_string(); + let expires_at = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + + 600; // 10 minutes + + let auth_code = AuthCode { + client_id: client_id.clone(), + redirect_uri: redirect_uri.clone(), + scope: params.get("scope").cloned(), + expires_at, + user_id: "test_user".to_string(), // In production, get from authentication + }; + + { + let mut codes = self.auth_codes.lock().unwrap(); + codes.insert(code.clone(), auth_code); + } + + // Build redirect URL with authorization code + let mut redirect_url = Url::parse(redirect_uri) + .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?; + + redirect_url.query_pairs_mut().append_pair("code", &code); + + if let Some(state) = params.get("state") { + redirect_url.query_pairs_mut().append_pair("state", state); + } + + Ok(redirect_url.to_string()) + } + + fn handle_token(&self, params: &HashMap<String, String>) -> Result<String, String> { + let grant_type = params + .get("grant_type") + .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; + + if grant_type != "authorization_code" { + return Err(self.error_response( + "unsupported_grant_type", + "Only authorization_code grant type supported", + )); + } + + let code = params + .get("code") + .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?; + + let client_id = params + .get("client_id") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + + // Validate authorization code + let auth_code = { + let mut codes = self.auth_codes.lock().unwrap(); + codes.remove(code).ok_or_else(|| { + self.error_response("invalid_grant", "Invalid or expired authorization code") + })? + }; + + // Verify client_id matches + if auth_code.client_id != *client_id { + return Err(self.error_response("invalid_grant", "Client ID mismatch")); + } + + // Check expiration + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + if now > auth_code.expires_at { + return Err(self.error_response("invalid_grant", "Authorization code expired")); + } + + // Generate access token + let access_token = + self.generate_access_token(&auth_code.user_id, client_id, &auth_code.scope)?; + + let token_response = TokenResponse { + access_token, + token_type: "Bearer".to_string(), + expires_in: 3600, // 1 hour + refresh_token: None, + scope: auth_code.scope, + }; + + serde_json::to_string(&token_response) + .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) + } + + fn generate_access_token( + &self, + user_id: &str, + client_id: &str, + scope: &Option<String>, + ) -> Result<String, String> { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let claims = Claims { + sub: user_id.to_string(), + iss: self.config.issuer_url.clone(), + aud: client_id.to_string(), + exp: now + 3600, // 1 hour + iat: now, + scope: scope.clone(), + }; + + encode(&Header::default(), &claims, &self.encoding_key) + .map_err(|_| self.error_response("server_error", "Failed to generate token")) + } + + fn error_response(&self, error: &str, description: &str) -> String { + let error_resp = ErrorResponse { + error: error.to_string(), + error_description: Some(description.to_string()), + }; + serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string()) + } } |
