summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-06 15:49:19 -0600
committermo khan <mo@mokhan.ca>2025-06-06 15:49:19 -0600
commit14c7a0e3ebf77451662bbbac1915facdec0bca3f (patch)
tree9473c21c06d425be2395398ec2a851c695c92a79 /src
parent463c259bd41f20d5811b028e8045f3de3effe097 (diff)
refactor: try vibe coding with claude
Diffstat (limited to 'src')
-rw-r--r--src/lib.rs451
-rw-r--r--src/main.rs52
2 files changed, 481 insertions, 22 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 947e852..f23c4a2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,42 +1,451 @@
+use base64::prelude::*;
+use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+use std::fs;
+use std::io::BufReader;
+use std::io::prelude::*;
+use std::net::{TcpListener, TcpStream};
+use std::time::{SystemTime, UNIX_EPOCH};
+use url::Url;
+use uuid::Uuid;
+
+#[derive(Debug, Clone)]
+pub struct Config {
+ pub bind_addr: String,
+ pub issuer_url: String,
+ pub jwt_secret: String,
+}
+
+impl Config {
+ pub fn from_env() -> Self {
+ let bind_addr = std::env::var("BIND_ADDR").unwrap_or_else(|_| "127.0.0.1:7878".to_string());
+ let issuer_url = format!("http://{}", bind_addr);
+ let jwt_secret = std::env::var("JWT_SECRET").unwrap_or_else(|_| {
+ "your-256-bit-secret-key-here-make-it-very-long-and-secure".to_string()
+ });
+
+ Self {
+ bind_addr,
+ issuer_url,
+ jwt_secret,
+ }
+ }
+}
+
pub mod http {
- use std::fs;
- use std::io::BufReader;
- use std::io::prelude::*;
- use std::net::TcpListener;
- use std::net::TcpStream;
+ use super::*;
+
pub struct Server {
- addr: String,
+ config: Config,
+ oauth_server: OAuthServer,
}
impl Server {
pub fn new(addr: String) -> Server {
- Server { addr }
+ let mut config = Config::from_env();
+ config.bind_addr = addr;
+ config.issuer_url = format!("http://{}", config.bind_addr);
+
+ Server {
+ oauth_server: OAuthServer::new(&config),
+ config,
+ }
}
pub fn start(&self) {
- let listener = TcpListener::bind(self.addr.clone()).unwrap();
- for next_stream in listener.incoming() {
- self.handle(next_stream.unwrap());
+ let listener = TcpListener::bind(self.config.bind_addr.clone()).unwrap();
+ println!("OAuth2 STS Server listening on {}", self.config.bind_addr);
+
+ for stream in listener.incoming() {
+ match stream {
+ Ok(stream) => self.handle(stream),
+ Err(e) => eprintln!("Error accepting connection: {}", e),
+ }
}
}
pub fn handle(&self, mut stream: TcpStream) {
- let io = BufReader::new(&stream);
- let request_line = io.lines().next().unwrap().unwrap();
+ let mut buffer = [0; 8192];
+ let bytes_read = stream.read(&mut buffer).unwrap_or(0);
+ let request = String::from_utf8_lossy(&buffer[..bytes_read]);
- let (status_line, filename) = match &request_line[..] {
- "GET / HTTP/1.1" => ("HTTP/1.1 200 OK", "./public/index.html"),
- "GET /.well-known/oauth-authorization-server HTTP/1.1" => {
- ("HTTP/1.1 200 OK", "./public/metadata.json")
+ let lines: Vec<&str> = request.lines().collect();
+ if lines.is_empty() {
+ self.send_error_response(&mut stream, 400, "Bad Request");
+ return;
+ }
+
+ let request_line = lines[0];
+ let parts: Vec<&str> = request_line.split_whitespace().collect();
+
+ if parts.len() != 3 {
+ self.send_error_response(&mut stream, 400, "Bad Request");
+ return;
+ }
+
+ let method = parts[0];
+ let path_and_query = parts[1];
+
+ // Parse URL and query parameters
+ let url = match Url::parse(&format!("http://localhost{}", path_and_query)) {
+ Ok(url) => url,
+ Err(_) => {
+ self.send_error_response(&mut stream, 400, "Bad Request");
+ return;
}
- _ => ("HTTP/1.1 404 NOT FOUND", "./public/404.html"),
};
- let contents = fs::read_to_string(filename).unwrap();
- let length = contents.len();
- let response = format!("{status_line}\r\nContent-Length: {length}\r\n\r\n{contents}");
+ let path = url.path();
+ let query_params: HashMap<String, String> = url
+ .query_pairs()
+ .map(|(k, v)| (k.to_string(), v.to_string()))
+ .collect();
+
+ match (method, path) {
+ ("GET", "/") => self.serve_static_file(&mut stream, "./public/index.html"),
+ ("GET", "/.well-known/oauth-authorization-server") => {
+ self.handle_metadata(&mut stream)
+ }
+ ("GET", "/jwks") => self.handle_jwks(&mut stream),
+ ("GET", "/authorize") => self.handle_authorize(&mut stream, &query_params),
+ ("POST", "/token") => self.handle_token(&mut stream, &request),
+ _ => self.send_error_response(&mut stream, 404, "Not Found"),
+ }
+ }
+
+ fn serve_static_file(&self, stream: &mut TcpStream, filename: &str) {
+ match fs::read_to_string(filename) {
+ Ok(contents) => {
+ let content_type = if filename.ends_with(".json") {
+ "application/json"
+ } else {
+ "text/html"
+ };
+
+ let response = format!(
+ "HTTP/1.1 200 OK\r\nContent-Type: {}\r\nContent-Length: {}\r\n\r\n{}",
+ content_type,
+ contents.len(),
+ contents
+ );
+ let _ = stream.write_all(response.as_bytes());
+ }
+ Err(_) => self.send_error_response(stream, 404, "Not Found"),
+ }
+ }
+
+ fn send_error_response(&self, stream: &mut TcpStream, status: u16, message: &str) {
+ let response = format!(
+ "HTTP/1.1 {} {}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\n\r\n{}",
+ status,
+ message,
+ message.len(),
+ message
+ );
+ let _ = stream.write_all(response.as_bytes());
+ }
+
+ fn send_json_response(
+ &self,
+ stream: &mut TcpStream,
+ status: u16,
+ status_text: &str,
+ json: &str,
+ ) {
+ let response = format!(
+ "HTTP/1.1 {} {}\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
+ status,
+ status_text,
+ json.len(),
+ json
+ );
+ let _ = stream.write_all(response.as_bytes());
+ }
+
+ fn handle_metadata(&self, stream: &mut TcpStream) {
+ let metadata = serde_json::json!({
+ "issuer": self.config.issuer_url,
+ "authorization_endpoint": format!("{}/authorize", self.config.issuer_url),
+ "token_endpoint": format!("{}/token", self.config.issuer_url),
+ "jwks_uri": format!("{}/jwks", self.config.issuer_url),
+ "scopes_supported": ["openid", "profile", "email"],
+ "response_types_supported": ["code"],
+ "response_modes_supported": ["query"],
+ "grant_types_supported": ["authorization_code"],
+ "subject_types_supported": ["public"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"]
+ });
+ self.send_json_response(stream, 200, "OK", &metadata.to_string());
+ }
+
+ fn handle_jwks(&self, stream: &mut TcpStream) {
+ let jwks = self.oauth_server.get_jwks();
+ self.send_json_response(stream, 200, "OK", &jwks);
+ }
+
+ fn handle_authorize(&self, stream: &mut TcpStream, params: &HashMap<String, String>) {
+ match self.oauth_server.handle_authorize(params) {
+ Ok(redirect_url) => {
+ let response = format!(
+ "HTTP/1.1 302 Found\r\nLocation: {}\r\nContent-Length: 0\r\n\r\n",
+ redirect_url
+ );
+ let _ = stream.write_all(response.as_bytes());
+ }
+ Err(error_response) => {
+ self.send_json_response(stream, 400, "Bad Request", &error_response);
+ }
+ }
+ }
+
+ fn handle_token(&self, stream: &mut TcpStream, request: &str) {
+ let body = self.extract_body(request);
+ let form_params = self.parse_form_data(&body);
+
+ match self.oauth_server.handle_token(&form_params) {
+ Ok(token_response) => {
+ self.send_json_response(stream, 200, "OK", &token_response);
+ }
+ Err(error_response) => {
+ self.send_json_response(stream, 400, "Bad Request", &error_response);
+ }
+ }
+ }
+
+ fn extract_body(&self, request: &str) -> String {
+ if let Some(pos) = request.find("\r\n\r\n") {
+ request[pos + 4..].to_string()
+ } else {
+ String::new()
+ }
+ }
+
+ fn parse_form_data(&self, body: &str) -> HashMap<String, String> {
+ body.split('&')
+ .filter_map(|pair| {
+ let mut split = pair.splitn(2, '=');
+ if let (Some(key), Some(value)) = (split.next(), split.next()) {
+ Some((
+ urlencoding::decode(key).unwrap_or_default().to_string(),
+ urlencoding::decode(value).unwrap_or_default().to_string(),
+ ))
+ } else {
+ None
+ }
+ })
+ .collect()
+ }
+ }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct Claims {
+ sub: String,
+ iss: String,
+ aud: String,
+ exp: u64,
+ iat: u64,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ scope: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct TokenResponse {
+ access_token: String,
+ token_type: String,
+ expires_in: u64,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ refresh_token: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ scope: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct ErrorResponse {
+ error: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ error_description: Option<String>,
+}
+
+pub struct OAuthServer {
+ config: Config,
+ encoding_key: EncodingKey,
+ decoding_key: DecodingKey,
+ auth_codes: std::sync::Mutex<HashMap<String, AuthCode>>,
+}
+
+#[derive(Debug, Clone)]
+struct AuthCode {
+ client_id: String,
+ redirect_uri: String,
+ scope: Option<String>,
+ expires_at: u64,
+ user_id: String,
+}
- stream.write_all(response.as_bytes()).unwrap();
+impl OAuthServer {
+ pub fn new(config: &Config) -> Self {
+ Self {
+ encoding_key: EncodingKey::from_secret(config.jwt_secret.as_ref()),
+ decoding_key: DecodingKey::from_secret(config.jwt_secret.as_ref()),
+ auth_codes: std::sync::Mutex::new(HashMap::new()),
+ config: config.clone(),
}
}
+
+ fn get_jwks(&self) -> String {
+ // For simplicity, returning empty JWKS. In production, include public key
+ serde_json::json!({
+ "keys": []
+ })
+ .to_string()
+ }
+
+ pub fn handle_authorize(&self, params: &HashMap<String, String>) -> Result<String, String> {
+ // Validate required parameters
+ let client_id = params
+ .get("client_id")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
+
+ let redirect_uri = params
+ .get("redirect_uri")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing redirect_uri"))?;
+
+ let response_type = params
+ .get("response_type")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?;
+
+ if response_type != "code" {
+ return Err(self.error_response(
+ "unsupported_response_type",
+ "Only code response type supported",
+ ));
+ }
+
+ // Generate authorization code
+ let code = Uuid::new_v4().to_string();
+ let expires_at = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap()
+ .as_secs()
+ + 600; // 10 minutes
+
+ let auth_code = AuthCode {
+ client_id: client_id.clone(),
+ redirect_uri: redirect_uri.clone(),
+ scope: params.get("scope").cloned(),
+ expires_at,
+ user_id: "test_user".to_string(), // In production, get from authentication
+ };
+
+ {
+ let mut codes = self.auth_codes.lock().unwrap();
+ codes.insert(code.clone(), auth_code);
+ }
+
+ // Build redirect URL with authorization code
+ let mut redirect_url = Url::parse(redirect_uri)
+ .map_err(|_| self.error_response("invalid_request", "Invalid redirect_uri"))?;
+
+ redirect_url.query_pairs_mut().append_pair("code", &code);
+
+ if let Some(state) = params.get("state") {
+ redirect_url.query_pairs_mut().append_pair("state", state);
+ }
+
+ Ok(redirect_url.to_string())
+ }
+
+ fn handle_token(&self, params: &HashMap<String, String>) -> Result<String, String> {
+ let grant_type = params
+ .get("grant_type")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?;
+
+ if grant_type != "authorization_code" {
+ return Err(self.error_response(
+ "unsupported_grant_type",
+ "Only authorization_code grant type supported",
+ ));
+ }
+
+ let code = params
+ .get("code")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?;
+
+ let client_id = params
+ .get("client_id")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
+
+ // Validate authorization code
+ let auth_code = {
+ let mut codes = self.auth_codes.lock().unwrap();
+ codes.remove(code).ok_or_else(|| {
+ self.error_response("invalid_grant", "Invalid or expired authorization code")
+ })?
+ };
+
+ // Verify client_id matches
+ if auth_code.client_id != *client_id {
+ return Err(self.error_response("invalid_grant", "Client ID mismatch"));
+ }
+
+ // Check expiration
+ let now = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ if now > auth_code.expires_at {
+ return Err(self.error_response("invalid_grant", "Authorization code expired"));
+ }
+
+ // Generate access token
+ let access_token =
+ self.generate_access_token(&auth_code.user_id, client_id, &auth_code.scope)?;
+
+ let token_response = TokenResponse {
+ access_token,
+ token_type: "Bearer".to_string(),
+ expires_in: 3600, // 1 hour
+ refresh_token: None,
+ scope: auth_code.scope,
+ };
+
+ serde_json::to_string(&token_response)
+ .map_err(|_| self.error_response("server_error", "Failed to serialize token response"))
+ }
+
+ fn generate_access_token(
+ &self,
+ user_id: &str,
+ client_id: &str,
+ scope: &Option<String>,
+ ) -> Result<String, String> {
+ let now = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ let claims = Claims {
+ sub: user_id.to_string(),
+ iss: self.config.issuer_url.clone(),
+ aud: client_id.to_string(),
+ exp: now + 3600, // 1 hour
+ iat: now,
+ scope: scope.clone(),
+ };
+
+ encode(&Header::default(), &claims, &self.encoding_key)
+ .map_err(|_| self.error_response("server_error", "Failed to generate token"))
+ }
+
+ fn error_response(&self, error: &str, description: &str) -> String {
+ let error_resp = ErrorResponse {
+ error: error.to_string(),
+ error_description: Some(description.to_string()),
+ };
+ serde_json::to_string(&error_resp).unwrap_or_else(|_| "{}".to_string())
+ }
}
diff --git a/src/main.rs b/src/main.rs
index 442bfe2..64f8fa3 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -12,8 +12,58 @@ fn main() {
#[cfg(test)]
mod tests {
+ use super::*;
+ use std::collections::HashMap;
+
#[test]
- fn it_starts_a_server() {
+ fn test_oauth_server_creation() {
+ let server = sts::http::Server::new("127.0.0.1:0".to_string());
+ // If we get here without panicking, the server was created successfully
assert!(true);
}
+
+ #[test]
+ fn test_authorization_code_generation() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config);
+ let mut params = HashMap::new();
+ params.insert("client_id".to_string(), "test_client".to_string());
+ params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ params.insert("response_type".to_string(), "code".to_string());
+ params.insert("state".to_string(), "test_state".to_string());
+
+ let result = oauth_server.handle_authorize(&params);
+ assert!(result.is_ok());
+
+ let redirect_url = result.unwrap();
+ assert!(redirect_url.contains("code="));
+ assert!(redirect_url.contains("state=test_state"));
+ }
+
+ #[test]
+ fn test_missing_client_id() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config);
+ let mut params = HashMap::new();
+ params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ params.insert("response_type".to_string(), "code".to_string());
+
+ let result = oauth_server.handle_authorize(&params);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().contains("invalid_request"));
+ }
+
+ #[test]
+ fn test_unsupported_response_type() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config);
+ let mut params = HashMap::new();
+ params.insert("client_id".to_string(), "test_client".to_string());
+ params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ params.insert("response_type".to_string(), "token".to_string());
+
+ let result = oauth_server.handle_authorize(&params);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().contains("unsupported_response_type"));
+ }
}