diff options
Diffstat (limited to 'src/oauth/server.rs')
| -rw-r--r-- | src/oauth/server.rs | 53 |
1 files changed, 46 insertions, 7 deletions
diff --git a/src/oauth/server.rs b/src/oauth/server.rs index 888b0c2..3ee0bb3 100644 --- a/src/oauth/server.rs +++ b/src/oauth/server.rs @@ -1,3 +1,4 @@ +use crate::clients::{ClientManager, parse_basic_auth}; use crate::config::Config; use crate::keys::KeyManager; use crate::oauth::types::{AuthCode, Claims, ErrorResponse, TokenResponse}; @@ -11,15 +12,18 @@ pub struct OAuthServer { config: Config, key_manager: std::sync::Mutex<KeyManager>, auth_codes: std::sync::Mutex<HashMap<String, AuthCode>>, + client_manager: std::sync::Mutex<ClientManager>, } impl OAuthServer { pub fn new(config: &Config) -> Result<Self, Box<dyn std::error::Error>> { let key_manager = KeyManager::new()?; + let client_manager = ClientManager::new(); Ok(Self { key_manager: std::sync::Mutex::new(key_manager), auth_codes: std::sync::Mutex::new(HashMap::new()), + client_manager: std::sync::Mutex::new(client_manager), config: config.clone(), }) } @@ -45,6 +49,22 @@ impl OAuthServer { .get("response_type") .ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?; + // Validate client exists + let client_manager = self.client_manager.lock().unwrap(); + let _client = client_manager.get_client(client_id) + .ok_or_else(|| self.error_response("invalid_client", "Invalid client_id"))?; + + // Validate redirect URI is registered for this client + if !client_manager.is_redirect_uri_valid(client_id, redirect_uri) { + return Err(self.error_response("invalid_request", "Invalid redirect_uri")); + } + + // Validate requested scopes + let scope = params.get("scope").cloned(); + if !client_manager.is_scope_valid(client_id, &scope) { + return Err(self.error_response("invalid_scope", "Invalid scope")); + } + if response_type != "code" { return Err(self.error_response( "unsupported_response_type", @@ -62,7 +82,7 @@ impl OAuthServer { let auth_code = AuthCode { client_id: client_id.clone(), redirect_uri: redirect_uri.clone(), - scope: params.get("scope").cloned(), + scope: scope, expires_at, user_id: "test_user".to_string(), }; @@ -84,7 +104,7 @@ impl OAuthServer { Ok(redirect_url.to_string()) } - pub fn handle_token(&self, params: &HashMap<String, String>) -> Result<String, String> { + pub fn handle_token(&self, params: &HashMap<String, String>, auth_header: Option<&str>) -> Result<String, String> { let grant_type = params .get("grant_type") .ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?; @@ -100,9 +120,28 @@ impl OAuthServer { .get("code") .ok_or_else(|| self.error_response("invalid_request", "Missing code"))?; - let client_id = params - .get("client_id") - .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + // Client authentication - RFC 6749 Section 3.2.1 + // Clients can authenticate via HTTP Basic Auth or form parameters + let (client_id, client_secret) = if let Some(auth_header) = auth_header { + // HTTP Basic Authentication (preferred method) + parse_basic_auth(auth_header) + .ok_or_else(|| self.error_response("invalid_client", "Invalid Authorization header"))? + } else { + // Form-based authentication (fallback) + let client_id = params + .get("client_id") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?; + let client_secret = params + .get("client_secret") + .ok_or_else(|| self.error_response("invalid_request", "Missing client_secret"))?; + (client_id.clone(), client_secret.clone()) + }; + + // Authenticate the client + let client_manager = self.client_manager.lock().unwrap(); + if !client_manager.authenticate_client(&client_id, &client_secret) { + return Err(self.error_response("invalid_client", "Client authentication failed")); + } let auth_code = { let mut codes = self.auth_codes.lock().unwrap(); @@ -111,7 +150,7 @@ impl OAuthServer { })? }; - if auth_code.client_id != *client_id { + if auth_code.client_id != client_id { return Err(self.error_response("invalid_grant", "Client ID mismatch")); } @@ -125,7 +164,7 @@ impl OAuthServer { } let access_token = - self.generate_access_token(&auth_code.user_id, client_id, &auth_code.scope)?; + self.generate_access_token(&auth_code.user_id, &client_id, &auth_code.scope)?; let token_response = TokenResponse { access_token, |
