summaryrefslogtreecommitdiff
path: root/src/oauth
diff options
context:
space:
mode:
Diffstat (limited to 'src/oauth')
-rw-r--r--src/oauth/server.rs53
1 files changed, 46 insertions, 7 deletions
diff --git a/src/oauth/server.rs b/src/oauth/server.rs
index 888b0c2..3ee0bb3 100644
--- a/src/oauth/server.rs
+++ b/src/oauth/server.rs
@@ -1,3 +1,4 @@
+use crate::clients::{ClientManager, parse_basic_auth};
use crate::config::Config;
use crate::keys::KeyManager;
use crate::oauth::types::{AuthCode, Claims, ErrorResponse, TokenResponse};
@@ -11,15 +12,18 @@ pub struct OAuthServer {
config: Config,
key_manager: std::sync::Mutex<KeyManager>,
auth_codes: std::sync::Mutex<HashMap<String, AuthCode>>,
+ client_manager: std::sync::Mutex<ClientManager>,
}
impl OAuthServer {
pub fn new(config: &Config) -> Result<Self, Box<dyn std::error::Error>> {
let key_manager = KeyManager::new()?;
+ let client_manager = ClientManager::new();
Ok(Self {
key_manager: std::sync::Mutex::new(key_manager),
auth_codes: std::sync::Mutex::new(HashMap::new()),
+ client_manager: std::sync::Mutex::new(client_manager),
config: config.clone(),
})
}
@@ -45,6 +49,22 @@ impl OAuthServer {
.get("response_type")
.ok_or_else(|| self.error_response("invalid_request", "Missing response_type"))?;
+ // Validate client exists
+ let client_manager = self.client_manager.lock().unwrap();
+ let _client = client_manager.get_client(client_id)
+ .ok_or_else(|| self.error_response("invalid_client", "Invalid client_id"))?;
+
+ // Validate redirect URI is registered for this client
+ if !client_manager.is_redirect_uri_valid(client_id, redirect_uri) {
+ return Err(self.error_response("invalid_request", "Invalid redirect_uri"));
+ }
+
+ // Validate requested scopes
+ let scope = params.get("scope").cloned();
+ if !client_manager.is_scope_valid(client_id, &scope) {
+ return Err(self.error_response("invalid_scope", "Invalid scope"));
+ }
+
if response_type != "code" {
return Err(self.error_response(
"unsupported_response_type",
@@ -62,7 +82,7 @@ impl OAuthServer {
let auth_code = AuthCode {
client_id: client_id.clone(),
redirect_uri: redirect_uri.clone(),
- scope: params.get("scope").cloned(),
+ scope: scope,
expires_at,
user_id: "test_user".to_string(),
};
@@ -84,7 +104,7 @@ impl OAuthServer {
Ok(redirect_url.to_string())
}
- pub fn handle_token(&self, params: &HashMap<String, String>) -> Result<String, String> {
+ pub fn handle_token(&self, params: &HashMap<String, String>, auth_header: Option<&str>) -> Result<String, String> {
let grant_type = params
.get("grant_type")
.ok_or_else(|| self.error_response("invalid_request", "Missing grant_type"))?;
@@ -100,9 +120,28 @@ impl OAuthServer {
.get("code")
.ok_or_else(|| self.error_response("invalid_request", "Missing code"))?;
- let client_id = params
- .get("client_id")
- .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
+ // Client authentication - RFC 6749 Section 3.2.1
+ // Clients can authenticate via HTTP Basic Auth or form parameters
+ let (client_id, client_secret) = if let Some(auth_header) = auth_header {
+ // HTTP Basic Authentication (preferred method)
+ parse_basic_auth(auth_header)
+ .ok_or_else(|| self.error_response("invalid_client", "Invalid Authorization header"))?
+ } else {
+ // Form-based authentication (fallback)
+ let client_id = params
+ .get("client_id")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing client_id"))?;
+ let client_secret = params
+ .get("client_secret")
+ .ok_or_else(|| self.error_response("invalid_request", "Missing client_secret"))?;
+ (client_id.clone(), client_secret.clone())
+ };
+
+ // Authenticate the client
+ let client_manager = self.client_manager.lock().unwrap();
+ if !client_manager.authenticate_client(&client_id, &client_secret) {
+ return Err(self.error_response("invalid_client", "Client authentication failed"));
+ }
let auth_code = {
let mut codes = self.auth_codes.lock().unwrap();
@@ -111,7 +150,7 @@ impl OAuthServer {
})?
};
- if auth_code.client_id != *client_id {
+ if auth_code.client_id != client_id {
return Err(self.error_response("invalid_grant", "Client ID mismatch"));
}
@@ -125,7 +164,7 @@ impl OAuthServer {
}
let access_token =
- self.generate_access_token(&auth_code.user_id, client_id, &auth_code.scope)?;
+ self.generate_access_token(&auth_code.user_id, &client_id, &auth_code.scope)?;
let token_response = TokenResponse {
access_token,