diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-12 09:59:14 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-12 09:59:14 -0600 |
| commit | 86b6f94b65cb3e821bab75439852afabfec78377 (patch) | |
| tree | f5eaac815ce0c4cc0ac1ee8ab25fe6eecb8dc95d /src | |
| parent | c28b7088b6fad045060a52b6e1a2249e876090e3 (diff) | |
Diffstat (limited to 'src')
| -rw-r--r-- | src/database.rs | 121 | ||||
| -rw-r--r-- | src/oauth/server.rs | 205 |
2 files changed, 304 insertions, 22 deletions
diff --git a/src/database.rs b/src/database.rs index 178eee3..0611efd 100644 --- a/src/database.rs +++ b/src/database.rs @@ -342,6 +342,127 @@ impl Database { Ok(()) } + // Refresh Token operations + pub fn create_refresh_token(&self, token: &DbRefreshToken) -> Result<i64> { + let mut stmt = self.conn.prepare( + "INSERT INTO refresh_tokens + (token_id, access_token_id, client_id, user_id, scope, expires_at, created_at, is_revoked, token_hash) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)", + )?; + + let id = stmt.insert(params![ + token.token_id, + token.access_token_id, + token.client_id, + token.user_id, + token.scope, + token.expires_at.to_rfc3339(), + token.created_at.to_rfc3339(), + token.is_revoked, + token.token_hash + ])?; + + Ok(id) + } + + pub fn get_refresh_token(&self, token_hash: &str) -> Result<Option<DbRefreshToken>> { + let mut stmt = self.conn.prepare( + "SELECT id, token_id, access_token_id, client_id, user_id, scope, expires_at, + created_at, is_revoked, token_hash + FROM refresh_tokens WHERE token_hash = ?1" + )?; + + let token = stmt.query_row([token_hash], |row| { + Ok(DbRefreshToken { + id: row.get(0)?, + token_id: row.get(1)?, + access_token_id: row.get(2)?, + client_id: row.get(3)?, + user_id: row.get(4)?, + scope: row.get(5)?, + expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 6, + "expires_at".to_string(), + rusqlite::types::Type::Text, + ) + })? + .with_timezone(&Utc), + created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?) + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 7, + "created_at".to_string(), + rusqlite::types::Type::Text, + ) + })? + .with_timezone(&Utc), + is_revoked: row.get(8)?, + token_hash: row.get(9)?, + }) + }); + + match token { + Ok(token) => Ok(Some(token)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(e.into()), + } + } + + pub fn revoke_refresh_token(&self, token_hash: &str) -> Result<()> { + self.conn.execute( + "UPDATE refresh_tokens SET is_revoked = 1 WHERE token_hash = ?1", + [token_hash], + )?; + Ok(()) + } + + pub fn get_refresh_token_by_access_token(&self, access_token_id: i64) -> Result<Option<DbRefreshToken>> { + let mut stmt = self.conn.prepare( + "SELECT id, token_id, access_token_id, client_id, user_id, scope, expires_at, + created_at, is_revoked, token_hash + FROM refresh_tokens WHERE access_token_id = ?1 AND is_revoked = 0" + )?; + + let token = stmt.query_row([access_token_id], |row| { + Ok(DbRefreshToken { + id: row.get(0)?, + token_id: row.get(1)?, + access_token_id: row.get(2)?, + client_id: row.get(3)?, + user_id: row.get(4)?, + scope: row.get(5)?, + expires_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(6)?) + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 6, + "expires_at".to_string(), + rusqlite::types::Type::Text, + ) + })? + .with_timezone(&Utc), + created_at: DateTime::parse_from_rfc3339(&row.get::<_, String>(7)?) + .map_err(|_| { + rusqlite::Error::InvalidColumnType( + 7, + "created_at".to_string(), + rusqlite::types::Type::Text, + ) + })? + .with_timezone(&Utc), + is_revoked: row.get(8)?, + token_hash: row.get(9)?, + }) + }); + + match token { + Ok(token) => Ok(Some(token)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(e.into()), + } + } + // RSA Key operations pub fn create_rsa_key(&self, key: &DbRsaKey) -> Result<i64> { let mut stmt = self.conn.prepare( diff --git a/src/oauth/server.rs b/src/oauth/server.rs index 37c3cbc..a951f7c 100644 --- a/src/oauth/server.rs +++ b/src/oauth/server.rs @@ -393,6 +393,21 @@ impl OAuthServer { } } + // Validate redirect_uri matches authorization request + if let Some(redirect_uri_param) = params.get("redirect_uri") { + if redirect_uri_param != &auth_code.redirect_uri { + self.audit_log( + "token_redirect_uri_mismatch", + Some(&client_id), + Some(&auth_code.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "redirect_uri mismatch")); + } + } + // Generate tokens let token_id = Uuid::new_v4().to_string(); let access_token = self.generate_access_token( @@ -401,10 +416,8 @@ impl OAuthServer { &auth_code.scope, &token_id, )?; - let refresh_token = - self.generate_refresh_token(&client_id, &auth_code.user_id, &auth_code.scope)?; - // Store token in database for revocation/introspection + // Store access token in database for revocation/introspection let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); let db_access_token = DbAccessToken { id: 0, @@ -415,15 +428,26 @@ impl OAuthServer { expires_at: Utc::now() + Duration::hours(1), created_at: Utc::now(), is_revoked: false, - token_hash, + token_hash: token_hash.clone(), }; - { + let access_token_db_id = { let db = self.database.lock().unwrap(); - if let Err(_) = db.create_access_token(&db_access_token) { - return Err(self.error_response("server_error", "Failed to store access token")); + match db.create_access_token(&db_access_token) { + Ok(id) => id, + Err(_) => { + return Err(self.error_response("server_error", "Failed to store access token")); + } } - } + }; + + // Generate and store refresh token + let refresh_token = self.generate_refresh_token( + access_token_db_id, + &client_id, + &auth_code.user_id, + &auth_code.scope + )?; let token_response = TokenResponse { access_token, @@ -452,7 +476,7 @@ impl OAuthServer { auth_header: Option<&str>, ip_address: Option<String>, ) -> Result<String, String> { - let _refresh_token = params + let refresh_token_str = params .get("refresh_token") .ok_or_else(|| self.error_response("invalid_request", "Missing refresh_token"))?; @@ -471,6 +495,19 @@ impl OAuthServer { (client_id.clone(), client_secret.clone()) }; + // Rate limiting check + if let Err(e) = self.check_rate_limit(&format!("client:{}", client_id), "/token") { + self.audit_log( + "refresh_rate_limited", + Some(&client_id), + None, + ip_address.as_deref(), + false, + Some(&e.to_string()), + ); + return Err(self.error_response("temporarily_unavailable", "Rate limit exceeded")); + } + // Authenticate the client { let mut client_manager = self.client_manager.lock().unwrap(); @@ -487,25 +524,125 @@ impl OAuthServer { } } - // Validate refresh token (implementation would verify token and get user info) - // For now, return a simple refresh token response + // Validate refresh token + let refresh_token_hash = format!("{:x}", Sha256::digest(refresh_token_str.as_bytes())); + let refresh_token = { + let db = self.database.lock().unwrap(); + match db.get_refresh_token(&refresh_token_hash) { + Ok(Some(token)) => token, + Ok(None) => { + self.audit_log( + "refresh_invalid_token", + Some(&client_id), + None, + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Invalid refresh token")); + } + Err(_) => { + return Err(self.error_response("server_error", "Internal server error")); + } + } + }; + + // Validate refresh token hasn't expired and isn't revoked + if refresh_token.is_revoked { + self.audit_log( + "refresh_token_revoked", + Some(&client_id), + Some(&refresh_token.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Refresh token revoked")); + } + + if Utc::now() > refresh_token.expires_at { + self.audit_log( + "refresh_token_expired", + Some(&client_id), + Some(&refresh_token.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Refresh token expired")); + } + + if refresh_token.client_id != client_id { + self.audit_log( + "refresh_client_mismatch", + Some(&client_id), + Some(&refresh_token.user_id), + ip_address.as_deref(), + false, + None, + ); + return Err(self.error_response("invalid_grant", "Client ID mismatch")); + } + + // Revoke the old refresh token (optional but recommended security practice) + { + let db = self.database.lock().unwrap(); + let _ = db.revoke_refresh_token(&refresh_token_hash); + } + + // Generate new tokens let new_token_id = Uuid::new_v4().to_string(); - let access_token = - self.generate_access_token("test_user", &client_id, &None, &new_token_id)?; - let new_refresh_token = self.generate_refresh_token(&client_id, "test_user", &None)?; + let access_token = self.generate_access_token( + &refresh_token.user_id, + &client_id, + &refresh_token.scope, + &new_token_id, + )?; + + // Store new access token + let token_hash = format!("{:x}", Sha256::digest(access_token.as_bytes())); + let db_access_token = DbAccessToken { + id: 0, + token_id: new_token_id.clone(), + client_id: client_id.clone(), + user_id: refresh_token.user_id.clone(), + scope: refresh_token.scope.clone(), + expires_at: Utc::now() + Duration::hours(1), + created_at: Utc::now(), + is_revoked: false, + token_hash: token_hash.clone(), + }; + + let access_token_db_id = { + let db = self.database.lock().unwrap(); + match db.create_access_token(&db_access_token) { + Ok(id) => id, + Err(_) => { + return Err(self.error_response("server_error", "Failed to store access token")); + } + } + }; + + // Generate new refresh token + let new_refresh_token = self.generate_refresh_token( + access_token_db_id, + &client_id, + &refresh_token.user_id, + &refresh_token.scope, + )?; let token_response = TokenResponse { access_token, token_type: "Bearer".to_string(), expires_in: 3600, refresh_token: Some(new_refresh_token), - scope: None, + scope: refresh_token.scope.clone(), }; self.audit_log( "refresh_success", Some(&client_id), - Some("test_user"), + Some(&refresh_token.user_id), ip_address.as_deref(), true, None, @@ -675,13 +812,37 @@ impl OAuthServer { fn generate_refresh_token( &self, - _client_id: &str, - _user_id: &str, - _scope: &Option<String>, + access_token_id: i64, + client_id: &str, + user_id: &str, + scope: &Option<String>, ) -> Result<String, String> { - // For now, return a simple UUID-based refresh token - // In production, this should be a proper JWT or encrypted token - Ok(Uuid::new_v4().to_string()) + use crate::database::DbRefreshToken; + + let refresh_token = Uuid::new_v4().to_string(); + let token_hash = format!("{:x}", Sha256::digest(refresh_token.as_bytes())); + + let db_refresh_token = DbRefreshToken { + id: 0, + token_id: Uuid::new_v4().to_string(), + access_token_id, + client_id: client_id.to_string(), + user_id: user_id.to_string(), + scope: scope.clone(), + expires_at: Utc::now() + Duration::days(30), // 30 day expiration for refresh tokens + created_at: Utc::now(), + is_revoked: false, + token_hash: token_hash.clone(), + }; + + { + let db = self.database.lock().unwrap(); + if let Err(_) = db.create_refresh_token(&db_refresh_token) { + return Err(self.error_response("server_error", "Failed to store refresh token")); + } + } + + Ok(refresh_token) } fn check_rate_limit(&self, identifier: &str, endpoint: &str) -> Result<()> { |
