summaryrefslogtreecommitdiff
path: root/src/lib.rs
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-09 16:43:16 -0600
committermo khan <mo@mokhan.ca>2025-06-09 16:43:16 -0600
commit2ef774d4c52b9fb0ae0d1717b7a3568b76bccf3d (patch)
treefde8c20a9333e68d7e798ec5936630375da2a1f9 /src/lib.rs
parentb39a50e3ec622294cc0b6f271f1996a89f1849d6 (diff)
refactor: split types into separate files
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs450
1 files changed, 6 insertions, 444 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 1231503..1563317 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,445 +1,7 @@
-use base64::prelude::*;
-use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
-use serde::{Deserialize, Serialize};
-use std::collections::HashMap;
-use std::fs;
-use std::io::BufReader;
-use std::io::prelude::*;
-use std::net::{TcpListener, TcpStream};
-use std::time::{SystemTime, UNIX_EPOCH};
-use url::Url;
-use uuid::Uuid;
+pub mod config;
+pub mod http;
+pub mod oauth;
-#[derive(Debug, Clone)]
-pub struct Config {
- pub bind_addr: String,
- pub issuer_url: String,
- pub jwt_secret: String,
-}
-
-impl Config {
- pub fn from_env() -> Self {
- let bind_addr = std::env::var("BIND_ADDR").unwrap_or_else(|_| "127.0.0.1:7878".to_string());
- let issuer_url = format!("http://{}", bind_addr);
- let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| {
- "your-256-bit-secret-key-here-make-it-very-long-and-secure".to_string()
- });
-
- Self {
- bind_addr,
- issuer_url,
- jwt_secret,
- }
- }
-}
-
-pub mod http {
- use super::*;
-
- pub struct Server {
- config: Config,
- oauth_server: OAuthServer,
- }
-
- impl Server {
- pub fn new(addr: String) -> Server {
- let mut config = Config::from_env();
- config.bind_addr = addr;
- config.issuer_url = format!("http://{}", config.bind_addr);
-
- Server {
- oauth_server: OAuthServer::new(&config),
- config,
- }
- }
-
- pub fn start(&self) {
- let listener = TcpListener::bind(self.config.bind_addr.clone()).unwrap();
- println!("Listening on {}", self.config.bind_addr);
-
- for stream in listener.incoming() {
- match stream {
- Ok(stream) => self.handle(stream),
- Err(e) => eprintln!("Error accepting connection: {}", e),
- }
- }
- }
-
- pub fn handle(&self, mut stream: TcpStream) {
- let mut buffer = [0; 8192];
- let bytes_read = stream.read(&mut buffer).unwrap_or(0);
- let request = String::from_utf8_lossy(&buffer[..bytes_read]);
-
- let lines: Vec<&str> = request.lines().collect();
- if lines.is_empty() {
- self.send_error_response(&mut stream, 400, "Bad Request");
- return;
- }
-
- let request_line = lines[0];
- let parts: Vec<&str> = request_line.split_whitespace().collect();
-
- if parts.len() != 3 {
- self.send_error_response(&mut stream, 400, "Bad Request");
- return;
- }
-
- let method = parts[0];
- let path_and_query = parts[1];
-
- // Parse URL and query parameters
- let url = match Url::parse(&format!("http://localhost{}", path_and_query)) {
- Ok(url) => url,
- Err(_) => {
- self.send_error_response(&mut stream, 400, "Bad Request");
- return;
- }
- };
-
- let path = url.path();
- let query_params: HashMap<String, String> = url
- .query_pairs()
- .map(|(k, v)| (k.to_string(), v.to_string()))
- .collect();
-
- match (method, path) {
- ("GET", "/") => self.serve_static_file(&mut stream, "./public/index.html"),
- ("GET", "/.well-known/oauth-authorization-server") => {
- self.handle_metadata(&mut stream)
- }
- ("GET", "/jwks") => self.handle_jwks(&mut stream),
- ("GET", "/authorize") => self.handle_authorize(&mut stream, &query_params),
- ("POST", "/token") => self.handle_token(&mut stream, &request),
- _ => self.send_error_response(&mut stream, 404, "Not Found"),
- }
- }
-
- fn serve_static_file(&self, stream: &mut TcpStream, filename: &str) {
- match fs::read_to_string(filename) {
- Ok(contents) => {
- let content_type = if filename.ends_with(".json") {
- "application/json"
- } else {
- "text/html"
- };
-
- let response = format!(
- "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n{}",
- content_type,
- contents.len(),
- contents
- );
- let _ = stream.write_all(response.as_bytes());
- }
- Err(_) => self.send_error_response(stream, 404, "Not Found"),
- }
- }
-
- fn send_error_response(&self, stream: &mut TcpStream, status: u16, message: &str) {
- let response = format!(
- "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}",
- status,
- message,
- message.len(),
- message
- );
- let _ = stream.write_all(response.as_bytes());
- }
-
- fn send_json_response(
- &self,
- stream: &mut TcpStream,
- status: u16,
- status_text: &str,
- json: &str,
- ) {
- let response = format!(
- "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
- status,
- status_text,
- json.len(),
- json
- );
- let _ = stream.write_all(response.as_bytes());
- }
-
- fn handle_metadata(&self, stream: &mut TcpStream) {
- let metadata = serde_json::json!({
- "issuer": self.config.issuer_url,
- "authorization_endpoint": format!("{}/authorize", self.config.issuer_url),
- "token_endpoint": format!("{}/token", self.config.issuer_url),
- "scopes_supported": ["openid", "profile", "email"],
- "response_types_supported": ["code"],
- });
- self.send_json_response(stream, 200, "OK", &metadata.to_string());
- }
-
- fn handle_jwks(&self, stream: &mut TcpStream) {
- let jwks = self.oauth_server.get_jwks();
- self.send_json_response(stream, 200, "OK", &jwks);
- }
-
- fn handle_authorize(&self, stream: &mut TcpStream, params: &HashMap<String, String>) {
- match self.oauth_server.handle_authorize(params) {
- Ok(redirect_url) => {
- let response = format!(
- "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\n\r\n",
- redirect_url
- );
- let _ = stream.write_all(response.as_bytes());
- }
- Err(error_response) => {
- self.send_json_response(stream, 400, "Bad Request", &error_response);
- }
- }
- }
-
- fn handle_token(&self, stream: &mut TcpStream, request: &str) {
- let body = self.extract_body(request);
- let form_params = self.parse_form_data(&body);
-
- match self.oauth_server.handle_token(&form_params) {
- Ok(token_response) => {
- self.send_json_response(stream, 200, "OK", &token_response);
- }
- Err(error_response) => {
- self.send_json_response(stream, 400, "Bad Request", &error_response);
- }
- }
- }
-
- fn extract_body(&self, request: &str) -> String {
- if let Some(pos) = request.find("\r\n\r\n") {
- request[pos + 4..].to_string()
- } else {
- String::new()
- }
- }
-
- fn parse_form_data(&self, body: &str) -> HashMap<String, String> {
- body.split('&')
- .filter_map(|pair| {
- let mut split = pair.splitn(2, '=');
- if let (Some(key), Some(value)) = (split.next(), split.next()) {
- Some((
- urlencoding::decode(key).unwrap_or_default().to_string(),
- urlencoding::decode(value).unwrap_or_default().to_string(),
- ))
- } else {
- None
- }
- })
- .collect()
- }
- }
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-struct Claims {
- sub: String,
- iss: String,
- aud: String,
- exp: u64,
- iat: u64,
- #[serde(skip_serializing_if = "Option::is_none")]
- scope: Option<String>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-struct TokenResponse {
- access_token: String,
- token_type: String,
- expires_in: u64,
- #[serde(skip_serializing_if = "Option::is_none")]
- refresh_token: Option<String>,
- #[serde(skip_serializing_if = "Option::is_none")]
- scope: Option<String>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-struct ErrorResponse {
- error: String,
- #[serde(skip_serializing_if = "Option::is_none")]
- error_description: Option<String>,
-}
-
-pub struct OAuthServer {
- config: Config,
- encoding_key: EncodingKey,
- decoding_key: DecodingKey,
- auth_codes: std::sync::Mutex<HashMap<String, AuthCode>>,
-}
-
-#[derive(Debug, Clone)]
-struct AuthCode {
- client_id: String,
- redirect_uri: String,
- scope: Option<String>,
- expires_at: u64,
- user_id: String,
-}
-
-impl OAuthServer {
- pub fn new(config: &Config) -> Self {
- Self {
- encoding_key: EncodingKey::from_secret(config.jwt_secret.as_ref()),
- decoding_key: DecodingKey::from_secret(config.jwt_secret.as_ref()),
- auth_codes: std::sync::Mutex::new(HashMap::new()),
- config: config.clone(),
- }
- }
-
- fn get_jwks(&self) -> String {
- // For simplicity, returning empty JWKS. In production, include public key
- serde_json::json!({
- "keys": []
- })
- .to_string()
- }
-
- pub fn handle_authorize(&self, params: &HashMap<String, String>) -> Result<String, String> {
- // Validate required parameters
- let client_id = params
- .get("client_id")
- .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
-
- let redirect_uri = params
- .get("redirect_uri")
- .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?;
-
- let response_type = params
- .get("response_type")
- .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?;
-
- if response_type != "code" {
- return Err(self.error_response(
- "unsupported_response_type",
- "Only code response type supported",
- ));
- }
-
- // Generate authorization code
- let code = Uuid::new_v4().to_string();
- let expires_at = SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .unwrap()
- .as_secs()
- + 600; // 10 minutes
-
- let auth_code = AuthCode {
- client_id: client_id.clone(),
- redirect_uri: redirect_uri.clone(),
- scope: params.get("scope").cloned(),
- expires_at,
- user_id: "test_user".to_string(), // In production, get from authentication
- };
-
- {
- let mut codes = self.auth_codes.lock().unwrap();
- codes.insert(code.clone(), auth_code);
- }
-
- // Build redirect URL with authorization code
- let mut redirect_url = Url::parse(redirect_uri)
- .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?;
-
- redirect_url.query_pairs_mut().append_pair("code", &code);
-
- if let Some(state) = params.get("state") {
- redirect_url.query_pairs_mut().append_pair("state", state);
- }
-
- Ok(redirect_url.to_string())
- }
-
- fn handle_token(&self, params: &HashMap<String, String>) -> Result<String, String> {
- let grant_type = params
- .get("grant_type")
- .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?;
-
- if grant_type != "authorization_code" {
- return Err(self.error_response(
- "unsupported_grant_type",
- "Only authorization_code grant type supported",
- ));
- }
-
- let code = params
- .get("code")
- .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?;
-
- let client_id = params
- .get("client_id")
- .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
-
- // Validate authorization code
- let auth_code = {
- let mut codes = self.auth_codes.lock().unwrap();
- codes.remove(code).ok_or_else(|| {
- self.error_response("invalid_grant", "Invalid or expired authorization code")
- })?
- };
-
- // Verify client_id matches
- if auth_code.client_id != *client_id {
- return Err(self.error_response("invalid_grant", "Client ID mismatch"));
- }
-
- // Check expiration
- let now = SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .unwrap()
- .as_secs();
-
- if now > auth_code.expires_at {
- return Err(self.error_response("invalid_grant", "Authorization code expired"));
- }
-
- // Generate access token
- let access_token =
- self.generate_access_token(&auth_code.user_id, client_id, &auth_code.scope)?;
-
- let token_response = TokenResponse {
- access_token,
- token_type: "Bearer".to_string(),
- expires_in: 3600, // 1 hour
- refresh_token: None,
- scope: auth_code.scope,
- };
-
- serde_json::to_string(&token_response)
- .map_err(|_| self.error_response("server_error", "Failed to serialize token response"))
- }
-
- fn generate_access_token(
- &self,
- user_id: &str,
- client_id: &str,
- scope: &Option<String>,
- ) -> Result<String, String> {
- let now = SystemTime::now()
- .duration_since(UNIX_EPOCH)
- .unwrap()
- .as_secs();
-
- let claims = Claims {
- sub: user_id.to_string(),
- iss: self.config.issuer_url.clone(),
- aud: client_id.to_string(),
- exp: now + 3600, // 1 hour
- iat: now,
- scope: scope.clone(),
- };
-
- encode(&Header::default(), &claims, &self.encoding_key)
- .map_err(|_| self.error_response("server_error", "Failed to generate token"))
- }
-
- fn error_response(&self, error: &str, description: &str) -> String {
- let error_resp = ErrorResponse {
- error: error.to_string(),
- error_description: Some(description.to_string()),
- };
- serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string())
- }
-}
+pub use config::Config;
+pub use http::Server;
+pub use oauth::OAuthServer;