use base64::prelude::*; use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fs; use std::io::BufReader; use std::io::prelude::*; use std::net::{TcpListener, TcpStream}; use std::time::{SystemTime, UNIX_EPOCH}; use url::Url; use uuid::Uuid; #[derive(Debug, Clone)] pub struct Config { pub bind_addr: String, pub issuer_url: String, pub jwt_secret: String, } impl Config { pub fn from_env() -> Self { let bind_addr = std::env::var("BIND_ADDR").unwrap_or_else(|_| "127.0.0.1:7878".to_string()); let issuer_url = format!("http://{}", bind_addr); let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| { "your-256-bit-secret-key-here-make-it-very-long-and-secure".to_string() }); Self { bind_addr, issuer_url, jwt_secret, } } } pub mod http { use super::*; pub struct Server { config: Config, oauth_server: OAuthServer, } impl Server { pub fn new(addr: String) -> Server { let mut config = Config::from_env(); config.bind_addr = addr; config.issuer_url = format!("http://{}", config.bind_addr); Server { oauth_server: OAuthServer::new(&config), config, } } pub fn start(&self) { let listener = TcpListener::bind(self.config.bind_addr.clone()).unwrap(); println!("OAuth2 STS Server listening on {}", self.config.bind_addr); for stream in listener.incoming() { match stream { Ok(stream) => self.handle(stream), Err(e) => eprintln!("Error accepting connection: {}", e), } } } pub fn handle(&self, mut stream: TcpStream) { let mut buffer = [0; 8192]; let bytes_read = stream.read(&mut buffer).unwrap_or(0); let request = String::from_utf8_lossy(&buffer[..bytes_read]); let lines: Vec<&str> = request.lines().collect(); if lines.is_empty() { self.send_error_response(&mut stream, 400, "Bad Request"); return; } let request_line = lines[0]; let parts: Vec<&str> = request_line.split_whitespace().collect(); if parts.len() != 3 { self.send_error_response(&mut stream, 400, "Bad Request"); return; } let method = parts[0]; let path_and_query = parts[1]; // Parse URL and query parameters let url = match Url::parse(&format!("http://localhost{}", path_and_query)) { Ok(url) => url, Err(_) => { self.send_error_response(&mut stream, 400, "Bad Request"); return; } }; let path = url.path(); let query_params: HashMap = url .query_pairs() .map(|(k, v)| (k.to_string(), v.to_string())) .collect(); match (method, path) { ("GET", "/") => self.serve_static_file(&mut stream, "./public/index.html"), ("GET", "/.well-known/oauth-authorization-server") => { self.handle_metadata(&mut stream) } ("GET", "/jwks") => self.handle_jwks(&mut stream), ("GET", "/authorize") => self.handle_authorize(&mut stream, &query_params), ("POST", "/token") => self.handle_token(&mut stream, &request), _ => self.send_error_response(&mut stream, 404, "Not Found"), } } fn serve_static_file(&self, stream: &mut TcpStream, filename: &str) { match fs::read_to_string(filename) { Ok(contents) => { let content_type = if filename.ends_with(".json") { "application/json" } else { "text/html" }; let response = format!( "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n{}", content_type, contents.len(), contents ); let _ = stream.write_all(response.as_bytes()); } Err(_) => self.send_error_response(stream, 404, "Not Found"), } } fn send_error_response(&self, stream: &mut TcpStream, status: u16, message: &str) { let response = format!( "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}", status, message, message.len(), message ); let _ = stream.write_all(response.as_bytes()); } fn send_json_response( &self, stream: &mut TcpStream, status: u16, status_text: &str, json: &str, ) { let response = format!( "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", status, status_text, json.len(), json ); let _ = stream.write_all(response.as_bytes()); } fn handle_metadata(&self, stream: &mut TcpStream) { let metadata = serde_json::json!({ "issuer": self.config.issuer_url, "authorization_endpoint": format!("{}/authorize", self.config.issuer_url), "token_endpoint": format!("{}/token", self.config.issuer_url), "jwks_uri": format!("{}/jwks", self.config.issuer_url), "scopes_supported": ["openid", "profile", "email"], "response_types_supported": ["code"], "response_modes_supported": ["query"], "grant_types_supported": ["authorization_code"], "subject_types_supported": ["public"], "id_token_signing_alg_values_supported": ["RS256"], "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"] }); self.send_json_response(stream, 200, "OK", &metadata.to_string()); } fn handle_jwks(&self, stream: &mut TcpStream) { let jwks = self.oauth_server.get_jwks(); self.send_json_response(stream, 200, "OK", &jwks); } fn handle_authorize(&self, stream: &mut TcpStream, params: &HashMap) { match self.oauth_server.handle_authorize(params) { Ok(redirect_url) => { let response = format!( "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\n\r\n", redirect_url ); let _ = stream.write_all(response.as_bytes()); } Err(error_response) => { self.send_json_response(stream, 400, "Bad Request", &error_response); } } } fn handle_token(&self, stream: &mut TcpStream, request: &str) { let body = self.extract_body(request); let form_params = self.parse_form_data(&body); match self.oauth_server.handle_token(&form_params) { Ok(token_response) => { self.send_json_response(stream, 200, "OK", &token_response); } Err(error_response) => { self.send_json_response(stream, 400, "Bad Request", &error_response); } } } fn extract_body(&self, request: &str) -> String { if let Some(pos) = request.find("\r\n\r\n") { request[pos + 4..].to_string() } else { String::new() } } fn parse_form_data(&self, body: &str) -> HashMap { body.split('&') .filter_map(|pair| { let mut split = pair.splitn(2, '='); if let (Some(key), Some(value)) = (split.next(), split.next()) { Some(( urlencoding::decode(key).unwrap_or_default().to_string(), urlencoding::decode(value).unwrap_or_default().to_string(), )) } else { None } }) .collect() } } } #[derive(Debug, Serialize, Deserialize)] struct Claims { sub: String, iss: String, aud: String, exp: u64, iat: u64, #[serde(skip_serializing_if = "Option::is_none")] scope: Option, } #[derive(Debug, Serialize, Deserialize)] struct TokenResponse { access_token: String, token_type: String, expires_in: u64, #[serde(skip_serializing_if = "Option::is_none")] refresh_token: Option, #[serde(skip_serializing_if = "Option::is_none")] scope: Option, } #[derive(Debug, Serialize, Deserialize)] struct ErrorResponse { error: String, #[serde(skip_serializing_if = "Option::is_none")] error_description: Option, } pub struct OAuthServer { config: Config, encoding_key: EncodingKey, decoding_key: DecodingKey, auth_codes: std::sync::Mutex>, } #[derive(Debug, Clone)] struct AuthCode { client_id: String, redirect_uri: String, scope: Option, expires_at: u64, user_id: String, } impl OAuthServer { pub fn new(config: &Config) -> Self { Self { encoding_key: EncodingKey::from_secret(config.jwt_secret.as_ref()), decoding_key: DecodingKey::from_secret(config.jwt_secret.as_ref()), auth_codes: std::sync::Mutex::new(HashMap::new()), config: config.clone(), } } fn get_jwks(&self) -> String { // For simplicity, returning empty JWKS. In production, include public key serde_json::json!({ "keys": [] }) .to_string() } pub fn handle_authorize(&self, params: &HashMap) -> Result { // Validate required parameters let client_id = params .get("client_id") .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; let redirect_uri = params .get("redirect_uri") .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?; let response_type = params .get("response_type") .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?; if response_type != "code" { return Err(self.error_response( "unsupported_response_type", "Only code response type supported", )); } // Generate authorization code let code = Uuid::new_v4().to_string(); let expires_at = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs() + 600; // 10 minutes let auth_code = AuthCode { client_id: client_id.clone(), redirect_uri: redirect_uri.clone(), scope: params.get("scope").cloned(), expires_at, user_id: "test_user".to_string(), // In production, get from authentication }; { let mut codes = self.auth_codes.lock().unwrap(); codes.insert(code.clone(), auth_code); } // Build redirect URL with authorization code let mut redirect_url = Url::parse(redirect_uri) .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?; redirect_url.query_pairs_mut().append_pair("code", &code); if let Some(state) = params.get("state") { redirect_url.query_pairs_mut().append_pair("state", state); } Ok(redirect_url.to_string()) } fn handle_token(&self, params: &HashMap) -> Result { let grant_type = params .get("grant_type") .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; if grant_type != "authorization_code" { return Err(self.error_response( "unsupported_grant_type", "Only authorization_code grant type supported", )); } let code = params .get("code") .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?; let client_id = params .get("client_id") .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; // Validate authorization code let auth_code = { let mut codes = self.auth_codes.lock().unwrap(); codes.remove(code).ok_or_else(|| { self.error_response("invalid_grant", "Invalid or expired authorization code") })? }; // Verify client_id matches if auth_code.client_id != *client_id { return Err(self.error_response("invalid_grant", "Client ID mismatch")); } // Check expiration let now = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); if now > auth_code.expires_at { return Err(self.error_response("invalid_grant", "Authorization code expired")); } // Generate access token let access_token = self.generate_access_token(&auth_code.user_id, client_id, &auth_code.scope)?; let token_response = TokenResponse { access_token, token_type: "Bearer".to_string(), expires_in: 3600, // 1 hour refresh_token: None, scope: auth_code.scope, }; serde_json::to_string(&token_response) .map_err(|_| self.error_response("server_error", "Failed to serialize token response")) } fn generate_access_token( &self, user_id: &str, client_id: &str, scope: &Option, ) -> Result { let now = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(); let claims = Claims { sub: user_id.to_string(), iss: self.config.issuer_url.clone(), aud: client_id.to_string(), exp: now + 3600, // 1 hour iat: now, scope: scope.clone(), }; encode(&Header::default(), &claims, &self.encoding_key) .map_err(|_| self.error_response("server_error", "Failed to generate token")) } fn error_response(&self, error: &str, description: &str) -> String { let error_resp = ErrorResponse { error: error.to_string(), error_description: Some(description.to_string()), }; serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string()) } }