diff options
Diffstat (limited to 'src/oauth/server.rs')
| -rw-r--r-- | src/oauth/server.rs | 205 |
1 files changed, 183 insertions, 22 deletions
diff --git a/src/oauth/server.rs b/src/oauth/server.rs index 37c3cbc..a951f7c 100644 --- a/src/oauth/server.rs +++ b/src/oauth/server.rs @@ -393,6 +393,21 @@ impl OAuthServer { } } + // Validate redirect_uri matches authorization request + if let Some(redirect_uri_param) = params.get("redirect_uri") { + if redirect_uri_param != &auth_code.redirect_uri { + self.audit_log( + "token_redirect_uri_mismatch", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "redirect_uri mismatch")); + } + } + // Generate tokens let token_id = Uuid::new_v4().to_string(); let access_token = self.generate_access_token( @@ -401,10 +416,8 @@ impl OAuthServer { &auth_code.scope, &token_id, )?; - let refresh_token = - self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?; - // Store token in database for revocation/introspection + // Store access token in database for revocation/introspection let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); let db_access_token = DbAccessToken { id: 0, @@ -415,15 +428,26 @@ impl OAuthServer { expires_at: Utc::now() + Duration::hours(1), created_at: Utc::now(), is_revoked: false, - token_hash, + token_hash: token_hash.clone(), }; - { + let access_token_db_id = { let db = self.database.lock().unwrap(); - if let Err(_) = db.create_access_token(&db_access_token) { - return Err(self.error_response("server_error", "Failed to store access token")); + match db.create_access_token(&db_access_token) { + Ok(id) => id, + Err(_) => { + return Err(self.error_response("server_error", "Failed to store access token")); + } } - } + }; + + // Generate and store refresh token + let refresh_token = self.generate_refresh_token( + access_token_db_id, + &client_id, + &auth_code.user_id, + &auth_code.scope + )?; let token_response = TokenResponse { access_token, @@ -452,7 +476,7 @@ impl OAuthServer { auth_header: Option<&str>, ip_address: Option<String>, ) -> Result<String, String> { - let _refresh_token = params + let refresh_token_str = params .get("refresh_token") .ok_or_else(|| self.error_response("invalid_request", "Missing refresh_token"))?; @@ -471,6 +495,19 @@ impl OAuthServer { (client_id.clone(), client_secret.clone()) }; + // Rate limiting check + if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/token") { + self.audit_log( + "refresh_rate_limited", + Some(&client_id), + None, + ip_address.as_deref(), + false, + Some(&e.to_string()), + ); + return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); + } + // Authenticate the client { let mut client_manager = self.client_manager.lock().unwrap(); @@ -487,25 +524,125 @@ impl OAuthServer { } } - // Validate refresh token (implementation would verify token and get user info) - // For now, return a simple refresh token response + // Validate refresh token + let refresh_token_hash = format!("{:x}", Sha256::digest(refresh_token_str.as_bytes())); + let refresh_token = { + let db = self.database.lock().unwrap(); + match db.get_refresh_token(&refresh_token_hash) { + Ok(Some(token)) => token, + Ok(None) => { + self.audit_log( + "refresh_invalid_token", + Some(&client_id), + None, + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Invalid refresh token")); + } + Err(_) => { + return Err(self.error_response("server_error", "Internal server error")); + } + } + }; + + // Validate refresh token hasn't expired and isn't revoked + if refresh_token.is_revoked { + self.audit_log( + "refresh_token_revoked", + Some(&client_id), + Some(&refresh_token.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Refresh token revoked")); + } + + if Utc::now() > refresh_token.expires_at { + self.audit_log( + "refresh_token_expired", + Some(&client_id), + Some(&refresh_token.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Refresh token expired")); + } + + if refresh_token.client_id != client_id { + self.audit_log( + "refresh_client_mismatch", + Some(&client_id), + Some(&refresh_token.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Client ID mismatch")); + } + + // Revoke the old refresh token (optional but recommended security practice) + { + let db = self.database.lock().unwrap(); + let _ = db.revoke_refresh_token(&refresh_token_hash); + } + + // Generate new tokens let new_token_id = Uuid::new_v4().to_string(); - let access_token = - self.generate_access_token("test_user", &client_id, &None, &new_token_id)?; - let new_refresh_token = self.generate_refresh_token(&client_id, "test_user", &None)?; + let access_token = self.generate_access_token( + &refresh_token.user_id, + &client_id, + &refresh_token.scope, + &new_token_id, + )?; + + // Store new access token + let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); + let db_access_token = DbAccessToken { + id: 0, + token_id: new_token_id.clone(), + client_id: client_id.clone(), + user_id: refresh_token.user_id.clone(), + scope: refresh_token.scope.clone(), + expires_at: Utc::now() + Duration::hours(1), + created_at: Utc::now(), + is_revoked: false, + token_hash: token_hash.clone(), + }; + + let access_token_db_id = { + let db = self.database.lock().unwrap(); + match db.create_access_token(&db_access_token) { + Ok(id) => id, + Err(_) => { + return Err(self.error_response("server_error", "Failed to store access token")); + } + } + }; + + // Generate new refresh token + let new_refresh_token = self.generate_refresh_token( + access_token_db_id, + &client_id, + &refresh_token.user_id, + &refresh_token.scope, + )?; let token_response = TokenResponse { access_token, token_type: "Bearer".to_string(), expires_in: 3600, refresh_token: Some(new_refresh_token), - scope: None, + scope: refresh_token.scope.clone(), }; self.audit_log( "refresh_success", Some(&client_id), - Some("test_user"), + Some(&refresh_token.user_id), ip_address.as_deref(), true, None, @@ -675,13 +812,37 @@ impl OAuthServer { fn generate_refresh_token( &self, - _client_id: &str, - _user_id: &str, - _scope: &Option<String>, + access_token_id: i64, + client_id: &str, + user_id: &str, + scope: &Option<String>, ) -> Result<String, String> { - // For now, return a simple UUID-based refresh token - // In production, this should be a proper JWT or encrypted token - Ok(Uuid::new_v4().to_string()) + use crate::database::DbRefreshToken; + + let refresh_token = Uuid::new_v4().to_string(); + let token_hash = format!("{:x}", Sha256::digest(refresh_token.as_bytes())); + + let db_refresh_token = DbRefreshToken { + id: 0, + token_id: Uuid::new_v4().to_string(), + access_token_id, + client_id: client_id.to_string(), + user_id: user_id.to_string(), + scope: scope.clone(), + expires_at: Utc::now() + Duration::days(30), // 30 day expiration for refresh tokens + created_at: Utc::now(), + is_revoked: false, + token_hash: token_hash.clone(), + }; + + { + let db = self.database.lock().unwrap(); + if let Err(_) = db.create_refresh_token(&db_refresh_token) { + return Err(self.error_response("server_error", "Failed to store refresh token")); + } + } + + Ok(refresh_token) } fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> { |
