summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-09 19:39:28 -0600
committermo khan <mo@mokhan.ca>2025-06-09 19:39:28 -0600
commit1b280d0b9be71daf9dffa0f4fa33559eda91946f (patch)
treecc301717deabe901ce82ce469fef6be84c1b2bed
parent50731865b1a22ab250e988aed2ea2bb8a18f9338 (diff)
Extract a registered client type
-rw-r--r--spec/integration/server_spec.rb38
-rw-r--r--src/clients.rs318
-rw-r--r--src/http/mod.rs15
-rw-r--r--src/lib.rs2
-rw-r--r--src/main.rs149
-rw-r--r--src/oauth/server.rs53
6 files changed, 545 insertions, 30 deletions
diff --git a/spec/integration/server_spec.rb b/spec/integration/server_spec.rb
index 6a041fb..229203c 100644
--- a/spec/integration/server_spec.rb
+++ b/spec/integration/server_spec.rb
@@ -40,25 +40,25 @@ RSpec.describe "Server" do
expect(json[:scopes_supported]).to match_array(["openid", "profile", "email"])
end
- pending 'returns optional fields' do
- expect(json[:response_modes_supported]).to eq("")
- expect(json[:jwks_uri]).to eq("#{base_url}/jwks.json")
- expect(json[:registration_endpoint]).to eq("#{base_url}/register")
- expect(json[:token_endpoint_auth_methods_supported]).to match_array(["client_secret_basic"])
- expect(json[:token_endpoint_auth_signing_alg_values_supported]).to match_array(["RS256"])
- expect(json[:service_documentation]).to eq("#{base_url}/service_documentation.html")
- expect(json[:ui_locales_supported]).to match_array(["en-US"])
- expect(json[:op_policy_uri]).to eq("")
- expect(json[:op_tos_uri]).to eq("")
- expect(json[:revocation_endpoint]).to eq("")
- expect(json[:revocation_endpoint_auth_methods_supported]).to eq("")
- expect(json[:revocation_endpoint_auth_signing_alg_values_supported]).to eq("")
- expect(json[:introspection_endpoint]).to eq("")
- expect(json[:introspection_endpoint_auth_methods_supported]).to eq("")
- expect(json[:introspection_endpoint_auth_signing_alg_values_supported]).to eq("")
- expect(json[:code_challenge_methods_supported]).to eq("")
- expect(json[:signed_metadata]).to eq("")
- expect(json[:grant_types_supported]).to match_array(["authorization_code"])
+ describe "optional fields" do
+ pending { expect(json[:response_modes_supported]).to eq("") }
+ pending { expect(json[:jwks_uri]).to eq("#{base_url}/jwks.json") }
+ pending { expect(json[:registration_endpoint]).to eq("#{base_url}/register") }
+ pending { expect(json[:token_endpoint_auth_methods_supported]).to match_array(["client_secret_basic"]) }
+ pending { expect(json[:token_endpoint_auth_signing_alg_values_supported]).to match_array(["RS256"]) }
+ pending { expect(json[:service_documentation]).to eq("#{base_url}/service_documentation.html") }
+ pending { expect(json[:ui_locales_supported]).to match_array(["en-US"]) }
+ pending { expect(json[:op_policy_uri]).to eq("") }
+ pending { expect(json[:op_tos_uri]).to eq("") }
+ pending { expect(json[:revocation_endpoint]).to eq("") }
+ pending { expect(json[:revocation_endpoint_auth_methods_supported]).to eq("") }
+ pending { expect(json[:revocation_endpoint_auth_signing_alg_values_supported]).to eq("") }
+ pending { expect(json[:introspection_endpoint]).to eq("") }
+ pending { expect(json[:introspection_endpoint_auth_methods_supported]).to eq("") }
+ pending { expect(json[:introspection_endpoint_auth_signing_alg_values_supported]).to eq("") }
+ pending { expect(json[:code_challenge_methods_supported]).to eq("") }
+ pending { expect(json[:signed_metadata]).to eq("") }
+ pending { expect(json[:grant_types_supported]).to match_array(["authorization_code"]) }
end
end
diff --git a/src/clients.rs b/src/clients.rs
new file mode 100644
index 0000000..38d7450
--- /dev/null
+++ b/src/clients.rs
@@ -0,0 +1,318 @@
+use base64::Engine;
+use serde::{Deserialize, Serialize};
+use sha2::{Digest, Sha256};
+use std::collections::HashMap;
+use uuid::Uuid;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct RegisteredClient {
+ pub client_id: String,
+ pub client_secret_hash: String,
+ pub redirect_uris: Vec<String>,
+ pub client_name: String,
+ pub scopes: Vec<String>,
+ pub grant_types: Vec<String>,
+ pub response_types: Vec<String>,
+ pub created_at: u64,
+}
+
+#[derive(Debug, Clone)]
+pub struct ClientCredentials {
+ pub client_id: String,
+ pub client_secret: String,
+}
+
+pub struct ClientManager {
+ clients: HashMap<String, RegisteredClient>,
+}
+
+impl ClientManager {
+ pub fn new() -> Self {
+ let mut manager = Self {
+ clients: HashMap::new(),
+ };
+
+ // Register a default test client for development
+ manager.register_client(
+ "test_client".to_string(),
+ "test_secret".to_string(),
+ vec!["http://localhost:3000/callback".to_string()],
+ "Test Client".to_string(),
+ vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
+ ).ok();
+
+ manager
+ }
+
+ pub fn register_client(
+ &mut self,
+ client_id: String,
+ client_secret: String,
+ redirect_uris: Vec<String>,
+ client_name: String,
+ scopes: Vec<String>,
+ ) -> Result<RegisteredClient, String> {
+ // Check if client_id already exists
+ if self.clients.contains_key(&client_id) {
+ return Err("Client ID already exists".to_string());
+ }
+
+ // Validate redirect URIs
+ for uri in &redirect_uris {
+ if !Self::is_valid_redirect_uri(uri) {
+ return Err(format!("Invalid redirect URI: {}", uri));
+ }
+ }
+
+ // Hash the client secret
+ let client_secret_hash = Self::hash_secret(&client_secret);
+
+ let client = RegisteredClient {
+ client_id: client_id.clone(),
+ client_secret_hash,
+ redirect_uris,
+ client_name,
+ scopes,
+ grant_types: vec!["authorization_code".to_string()],
+ response_types: vec!["code".to_string()],
+ created_at: std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap()
+ .as_secs(),
+ };
+
+ self.clients.insert(client_id, client.clone());
+ Ok(client)
+ }
+
+ pub fn get_client(&self, client_id: &str) -> Option<&RegisteredClient> {
+ self.clients.get(client_id)
+ }
+
+ pub fn authenticate_client(&self, client_id: &str, client_secret: &str) -> bool {
+ if let Some(client) = self.get_client(client_id) {
+ let provided_hash = Self::hash_secret(client_secret);
+ // Use constant-time comparison to prevent timing attacks
+ self.constant_time_eq(&client.client_secret_hash, &provided_hash)
+ } else {
+ // Still perform hashing even for non-existent clients to prevent timing attacks
+ Self::hash_secret(client_secret);
+ false
+ }
+ }
+
+ pub fn is_redirect_uri_valid(&self, client_id: &str, redirect_uri: &str) -> bool {
+ if let Some(client) = self.get_client(client_id) {
+ client.redirect_uris.contains(&redirect_uri.to_string())
+ } else {
+ false
+ }
+ }
+
+ pub fn is_scope_valid(&self, client_id: &str, requested_scopes: &Option<String>) -> bool {
+ if let Some(client) = self.get_client(client_id) {
+ if let Some(scopes_str) = requested_scopes {
+ let requested: Vec<&str> = scopes_str.split_whitespace().collect();
+ requested.iter().all(|scope| client.scopes.contains(&scope.to_string()))
+ } else {
+ true // No scopes requested is valid
+ }
+ } else {
+ false
+ }
+ }
+
+ fn hash_secret(secret: &str) -> String {
+ let mut hasher = Sha256::new();
+ hasher.update(secret.as_bytes());
+ format!("{:x}", hasher.finalize())
+ }
+
+ fn is_valid_redirect_uri(uri: &str) -> bool {
+ // Basic validation - in production, this should be more comprehensive
+ uri.starts_with("http://") || uri.starts_with("https://")
+ }
+
+ // Constant-time string comparison to prevent timing attacks
+ fn constant_time_eq(&self, a: &str, b: &str) -> bool {
+ if a.len() != b.len() {
+ return false;
+ }
+
+ let mut result = 0u8;
+ for (byte_a, byte_b) in a.bytes().zip(b.bytes()) {
+ result |= byte_a ^ byte_b;
+ }
+ result == 0
+ }
+
+ pub fn generate_client_credentials(&mut self, client_name: String, redirect_uris: Vec<String>) -> Result<ClientCredentials, String> {
+ let client_id = format!("client_{}", Uuid::new_v4().to_string().replace("-", ""));
+ let client_secret = Uuid::new_v4().to_string();
+
+ self.register_client(
+ client_id.clone(),
+ client_secret.clone(),
+ redirect_uris,
+ client_name,
+ vec!["openid".to_string(), "profile".to_string(), "email".to_string()],
+ )?;
+
+ Ok(ClientCredentials {
+ client_id,
+ client_secret,
+ })
+ }
+
+ pub fn list_clients(&self) -> Vec<&RegisteredClient> {
+ self.clients.values().collect()
+ }
+}
+
+// HTTP Basic Auth parsing helper
+pub fn parse_basic_auth(auth_header: &str) -> Option<(String, String)> {
+ if !auth_header.starts_with("Basic ") {
+ return None;
+ }
+
+ let encoded = &auth_header[6..];
+ let decoded = base64::engine::general_purpose::STANDARD
+ .decode(encoded)
+ .ok()?;
+
+ let credentials = String::from_utf8(decoded).ok()?;
+ let mut parts = credentials.splitn(2, ':');
+
+ let username = parts.next()?.to_string();
+ let password = parts.next()?.to_string();
+
+ Some((username, password))
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_client_registration() {
+ let mut manager = ClientManager::new();
+
+ let result = manager.register_client(
+ "new_client".to_string(),
+ "secret123".to_string(),
+ vec!["https://example.com/callback".to_string()],
+ "Test App".to_string(),
+ vec!["openid".to_string()],
+ );
+
+ assert!(result.is_ok());
+ let client = result.unwrap();
+ assert_eq!(client.client_id, "new_client");
+ assert_eq!(client.client_name, "Test App");
+ }
+
+ #[test]
+ fn test_duplicate_client_id() {
+ let mut manager = ClientManager::new();
+
+ manager.register_client(
+ "duplicate".to_string(),
+ "secret1".to_string(),
+ vec!["https://example.com/callback".to_string()],
+ "App 1".to_string(),
+ vec!["openid".to_string()],
+ ).unwrap();
+
+ let result = manager.register_client(
+ "duplicate".to_string(),
+ "secret2".to_string(),
+ vec!["https://example.com/callback".to_string()],
+ "App 2".to_string(),
+ vec!["openid".to_string()],
+ );
+
+ assert!(result.is_err());
+ }
+
+ #[test]
+ fn test_client_authentication() {
+ let mut manager = ClientManager::new();
+
+ manager.register_client(
+ "auth_client".to_string(),
+ "correct_secret".to_string(),
+ vec!["https://example.com/callback".to_string()],
+ "Auth Test".to_string(),
+ vec!["openid".to_string()],
+ ).unwrap();
+
+ assert!(manager.authenticate_client("auth_client", "correct_secret"));
+ assert!(!manager.authenticate_client("auth_client", "wrong_secret"));
+ assert!(!manager.authenticate_client("nonexistent", "any_secret"));
+ }
+
+ #[test]
+ fn test_redirect_uri_validation() {
+ let mut manager = ClientManager::new();
+
+ manager.register_client(
+ "uri_client".to_string(),
+ "secret".to_string(),
+ vec!["https://app.com/callback".to_string(), "http://localhost:3000/callback".to_string()],
+ "URI Test".to_string(),
+ vec!["openid".to_string()],
+ ).unwrap();
+
+ assert!(manager.is_redirect_uri_valid("uri_client", "https://app.com/callback"));
+ assert!(manager.is_redirect_uri_valid("uri_client", "http://localhost:3000/callback"));
+ assert!(!manager.is_redirect_uri_valid("uri_client", "https://evil.com/callback"));
+ assert!(!manager.is_redirect_uri_valid("nonexistent", "https://app.com/callback"));
+ }
+
+ #[test]
+ fn test_scope_validation() {
+ let mut manager = ClientManager::new();
+
+ manager.register_client(
+ "scope_client".to_string(),
+ "secret".to_string(),
+ vec!["https://example.com/callback".to_string()],
+ "Scope Test".to_string(),
+ vec!["openid".to_string(), "profile".to_string()],
+ ).unwrap();
+
+ assert!(manager.is_scope_valid("scope_client", &Some("openid".to_string())));
+ assert!(manager.is_scope_valid("scope_client", &Some("openid profile".to_string())));
+ assert!(!manager.is_scope_valid("scope_client", &Some("openid profile email".to_string())));
+ assert!(manager.is_scope_valid("scope_client", &None));
+ }
+
+ #[test]
+ fn test_basic_auth_parsing() {
+ let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; // test_client:test_secret
+ let result = parse_basic_auth(auth_header);
+
+ assert!(result.is_some());
+ let (username, password) = result.unwrap();
+ assert_eq!(username, "test_client");
+ assert_eq!(password, "test_secret");
+ }
+
+ #[test]
+ fn test_generate_client_credentials() {
+ let mut manager = ClientManager::new();
+
+ let result = manager.generate_client_credentials(
+ "Generated App".to_string(),
+ vec!["https://generated.com/callback".to_string()],
+ );
+
+ assert!(result.is_ok());
+ let credentials = result.unwrap();
+ assert!(credentials.client_id.starts_with("client_"));
+ assert!(!credentials.client_secret.is_empty());
+
+ // Verify the client was actually registered
+ assert!(manager.authenticate_client(&credentials.client_id, &credentials.client_secret));
+ }
+} \ No newline at end of file
diff --git a/src/http/mod.rs b/src/http/mod.rs
index 4523d3b..7b1b983 100644
--- a/src/http/mod.rs
+++ b/src/http/mod.rs
@@ -166,8 +166,11 @@ impl Server {
fn handle_token(&self, stream: &mut TcpStream, request: &str) {
let body = self.extract_body(request);
let form_params = self.parse_form_data(&body);
+
+ // Extract Authorization header from request
+ let auth_header = self.extract_auth_header(request);
- match self.oauth_server.handle_token(&form_params) {
+ match self.oauth_server.handle_token(&form_params, auth_header.as_deref()) {
Ok(token_response) => {
self.send_json_response(stream, 200, "OK", &token_response);
}
@@ -200,4 +203,14 @@ impl Server {
})
.collect()
}
+
+ fn extract_auth_header(&self, request: &str) -> Option<String> {
+ let lines: Vec<&str> = request.lines().collect();
+ for line in lines.iter().skip(1) { // Skip the request line
+ if line.to_lowercase().starts_with("authorization:") {
+ return Some(line[14..].trim().to_string()); // Skip "Authorization: "
+ }
+ }
+ None
+ }
} \ No newline at end of file
diff --git a/src/lib.rs b/src/lib.rs
index 1bff27d..4ed4b7d 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,8 +1,10 @@
+pub mod clients;
pub mod config;
pub mod http;
pub mod keys;
pub mod oauth;
+pub use clients::ClientManager;
pub use config::Config;
pub use http::Server;
pub use oauth::OAuthServer;
diff --git a/src/main.rs b/src/main.rs
index 9b6b0ea..6b31f92 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -124,8 +124,9 @@ mod tests {
token_params.insert("grant_type".to_string(), "authorization_code".to_string());
token_params.insert("code".to_string(), auth_code);
token_params.insert("client_id".to_string(), "test_client".to_string());
+ token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params).expect("Token request failed");
+ let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed");
// Parse token response
let token_response: serde_json::Value = serde_json::from_str(&token_result)
@@ -167,8 +168,9 @@ mod tests {
token_params.insert("grant_type".to_string(), "authorization_code".to_string());
token_params.insert("code".to_string(), auth_code);
token_params.insert("client_id".to_string(), "test_client".to_string());
+ token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params).expect("Token request failed");
+ let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed");
let token_response: serde_json::Value = serde_json::from_str(&token_result)
.expect("Invalid token response JSON");
let access_token = token_response["access_token"].as_str().unwrap();
@@ -217,8 +219,9 @@ mod tests {
token_params.insert("grant_type".to_string(), "authorization_code".to_string());
token_params.insert("code".to_string(), auth_code);
token_params.insert("client_id".to_string(), "test_client".to_string());
+ token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params).expect("Token request failed");
+ let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed");
let token_response: serde_json::Value = serde_json::from_str(&token_result)
.expect("Invalid token response JSON");
let access_token = token_response["access_token"].as_str().unwrap();
@@ -248,4 +251,144 @@ mod tests {
assert_eq!(claims["aud"], "test_client");
assert_eq!(claims["scope"], "openid profile");
}
+
+ #[test]
+ fn test_invalid_client_id_authorization() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ let mut params = HashMap::new();
+ params.insert("client_id".to_string(), "invalid_client".to_string());
+ params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ params.insert("response_type".to_string(), "code".to_string());
+
+ let result = oauth_server.handle_authorize(&params);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().contains("invalid_client"));
+ }
+
+ #[test]
+ fn test_invalid_redirect_uri() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ let mut params = HashMap::new();
+ params.insert("client_id".to_string(), "test_client".to_string());
+ params.insert("redirect_uri".to_string(), "https://evil.com/callback".to_string());
+ params.insert("response_type".to_string(), "code".to_string());
+
+ let result = oauth_server.handle_authorize(&params);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().contains("Invalid redirect_uri"));
+ }
+
+ #[test]
+ fn test_invalid_scope() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ let mut params = HashMap::new();
+ params.insert("client_id".to_string(), "test_client".to_string());
+ params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ params.insert("response_type".to_string(), "code".to_string());
+ params.insert("scope".to_string(), "invalid_scope".to_string());
+
+ let result = oauth_server.handle_authorize(&params);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().contains("invalid_scope"));
+ }
+
+ #[test]
+ fn test_invalid_client_secret() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ // First get an authorization code
+ let mut auth_params = HashMap::new();
+ auth_params.insert("client_id".to_string(), "test_client".to_string());
+ auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert("response_type".to_string(), "code".to_string());
+
+ let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
+ let auth_code = redirect_url
+ .query_pairs()
+ .find(|(key, _)| key == "code")
+ .map(|(_, value)| value.to_string())
+ .expect("No authorization code in redirect");
+
+ // Try to exchange with wrong client secret
+ let mut token_params = HashMap::new();
+ token_params.insert("grant_type".to_string(), "authorization_code".to_string());
+ token_params.insert("code".to_string(), auth_code);
+ token_params.insert("client_id".to_string(), "test_client".to_string());
+ token_params.insert("client_secret".to_string(), "wrong_secret".to_string());
+
+ let result = oauth_server.handle_token(&token_params, None);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().contains("invalid_client"));
+ }
+
+ #[test]
+ fn test_missing_client_secret() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ // First get an authorization code
+ let mut auth_params = HashMap::new();
+ auth_params.insert("client_id".to_string(), "test_client".to_string());
+ auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert("response_type".to_string(), "code".to_string());
+
+ let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
+ let auth_code = redirect_url
+ .query_pairs()
+ .find(|(key, _)| key == "code")
+ .map(|(_, value)| value.to_string())
+ .expect("No authorization code in redirect");
+
+ // Try to exchange without client secret
+ let mut token_params = HashMap::new();
+ token_params.insert("grant_type".to_string(), "authorization_code".to_string());
+ token_params.insert("code".to_string(), auth_code);
+ token_params.insert("client_id".to_string(), "test_client".to_string());
+ // Missing client_secret
+
+ let result = oauth_server.handle_token(&token_params, None);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().contains("Missing client_secret"));
+ }
+
+ #[test]
+ fn test_http_basic_auth() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ // First get an authorization code
+ let mut auth_params = HashMap::new();
+ auth_params.insert("client_id".to_string(), "test_client".to_string());
+ auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert("response_type".to_string(), "code".to_string());
+
+ let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
+ let auth_code = redirect_url
+ .query_pairs()
+ .find(|(key, _)| key == "code")
+ .map(|(_, value)| value.to_string())
+ .expect("No authorization code in redirect");
+
+ // Use HTTP Basic Auth instead of form parameters
+ let mut token_params = HashMap::new();
+ token_params.insert("grant_type".to_string(), "authorization_code".to_string());
+ token_params.insert("code".to_string(), auth_code);
+ // No client_id/client_secret in form
+
+ // test_client:test_secret encoded in base64
+ let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ=";
+
+ let result = oauth_server.handle_token(&token_params, Some(auth_header));
+ assert!(result.is_ok());
+ }
}
diff --git a/src/oauth/server.rs b/src/oauth/server.rs
index 888b0c2..3ee0bb3 100644
--- a/src/oauth/server.rs
+++ b/src/oauth/server.rs
@@ -1,3 +1,4 @@
+use crate::clients::{ClientManager, parse_basic_auth};
use crate::config::Config;
use crate::keys::KeyManager;
use crate::oauth::types::{AuthCode, Claims, ErrorResponse, TokenResponse};
@@ -11,15 +12,18 @@ pub struct OAuthServer {
config: Config,
key_manager: std::sync::Mutex<KeyManager>,
auth_codes: std::sync::Mutex<HashMap<String, AuthCode>>,
+ client_manager: std::sync::Mutex<ClientManager>,
}
impl OAuthServer {
pub fn new(config: &Config) -> Result<Self, Box<dyn std::error::Error>> {
let key_manager = KeyManager::new()?;
+ let client_manager = ClientManager::new();
Ok(Self {
key_manager: std::sync::Mutex::new(key_manager),
auth_codes: std::sync::Mutex::new(HashMap::new()),
+ client_manager: std::sync::Mutex::new(client_manager),
config: config.clone(),
})
}
@@ -45,6 +49,22 @@ impl OAuthServer {
.get("response_type")
.ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?;
+ // Validate client exists
+ let client_manager = self.client_manager.lock().unwrap();
+ let _client = client_manager.get_client(client_id)
+ .ok_or_else(|| self.error_response("invalid_client", "Invalid client_id"))?;
+
+ // Validate redirect URI is registered for this client
+ if !client_manager.is_redirect_uri_valid(client_id, redirect_uri) {
+ return Err(self.error_response("invalid_request", "Invalid redirect_uri"));
+ }
+
+ // Validate requested scopes
+ let scope = params.get("scope").cloned();
+ if !client_manager.is_scope_valid(client_id, &scope) {
+ return Err(self.error_response("invalid_scope", "Invalid scope"));
+ }
+
if response_type != "code" {
return Err(self.error_response(
"unsupported_response_type",
@@ -62,7 +82,7 @@ impl OAuthServer {
let auth_code = AuthCode {
client_id: client_id.clone(),
redirect_uri: redirect_uri.clone(),
- scope: params.get("scope").cloned(),
+ scope: scope,
expires_at,
user_id: "test_user".to_string(),
};
@@ -84,7 +104,7 @@ impl OAuthServer {
Ok(redirect_url.to_string())
}
- pub fn handle_token(&self, params: &HashMap<String, String>) -> Result<String, String> {
+ pub fn handle_token(&self, params: &HashMap<String, String>, auth_header: Option<&str>) -> Result<String, String> {
let grant_type = params
.get("grant_type")
.ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?;
@@ -100,9 +120,28 @@ impl OAuthServer {
.get("code")
.ok_or_else(|| self.error_response("invalid_request", "Missing code"))?;
- let client_id = params
- .get("client_id")
- .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
+ // Client authentication - RFC 6749 Section 3.2.1
+ // Clients can authenticate via HTTP Basic Auth or form parameters
+ let (client_id, client_secret) = if let Some(auth_header) = auth_header {
+ // HTTP Basic Authentication (preferred method)
+ parse_basic_auth(auth_header)
+ .ok_or_else(|| self.error_response("invalid_client", "Invalid Authorization header"))?
+ } else {
+ // Form-based authentication (fallback)
+ let client_id = params
+ .get("client_id")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
+ let client_secret = params
+ .get("client_secret")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing client_secret"))?;
+ (client_id.clone(), client_secret.clone())
+ };
+
+ // Authenticate the client
+ let client_manager = self.client_manager.lock().unwrap();
+ if !client_manager.authenticate_client(&client_id, &client_secret) {
+ return Err(self.error_response("invalid_client", "Client authentication failed"));
+ }
let auth_code = {
let mut codes = self.auth_codes.lock().unwrap();
@@ -111,7 +150,7 @@ impl OAuthServer {
})?
};
- if auth_code.client_id != *client_id {
+ if auth_code.client_id != client_id {
return Err(self.error_response("invalid_grant", "Client ID mismatch"));
}
@@ -125,7 +164,7 @@ impl OAuthServer {
}
let access_token =
- self.generate_access_token(&auth_code.user_id, client_id, &auth_code.scope)?;
+ self.generate_access_token(&auth_code.user_id, &client_id, &auth_code.scope)?;
let token_response = TokenResponse {
access_token,