diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-11 15:48:45 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-11 15:48:45 -0600 |
| commit | dbd3c780f27bd5bee23adf6e280b84d669230e0d (patch) | |
| tree | 21969009f17f8d58e35cf9a73a3ed77eb0e3faca | |
| parent | aea6bd6ec7d7e70a67723edf6327df4a9cc65d89 (diff) | |
test: fix commented out tests
| -rw-r--r-- | src/main.rs | 79 | ||||
| -rw-r--r-- | src/migrations.rs | 26 |
2 files changed, 63 insertions, 42 deletions
diff --git a/src/main.rs b/src/main.rs index 0612bfa..ac47a5e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -34,24 +34,30 @@ fn main() { server.start(); } -/* #[cfg(test)] -mod disabled_tests { - use std::collections::HashMap; +mod tests { use base64::Engine; + use std::collections::HashMap; - #[test] - fn test_oauth_server_creation() { + fn setup_test_environment() -> sts::Config { let mut config = sts::Config::from_env(); + config.database_path = ":memory:".to_string(); // Use in-memory database for tests config.bind_addr = "127.0.0.1:0".to_string(); config.issuer_url = format!("http://{}", config.bind_addr); + + config + } + + #[test] + fn test_oauth_server_creation() { + let config = setup_test_environment(); let server = sts::http::Server::new(config); assert!(server.is_ok()); } #[test] fn test_authorization_code_generation() { - let config = sts::Config::from_env(); + let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let mut params = HashMap::new(); params.insert("client_id".to_string(), "test_client".to_string()); @@ -62,7 +68,7 @@ mod disabled_tests { params.insert("response_type".to_string(), "code".to_string()); params.insert("state".to_string(), "test_state".to_string()); - let result = oauth_server.handle_authorize(¶ms); + let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); assert!(result.is_ok()); let redirect_url = result.unwrap(); @@ -72,7 +78,7 @@ mod disabled_tests { #[test] fn test_missing_client_id() { - let config = sts::Config::from_env(); + let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let mut params = HashMap::new(); params.insert( @@ -81,14 +87,14 @@ mod disabled_tests { ); params.insert("response_type".to_string(), "code".to_string()); - let result = oauth_server.handle_authorize(¶ms); + let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("invalid_request")); } #[test] fn test_unsupported_response_type() { - let config = sts::Config::from_env(); + let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let mut params = HashMap::new(); params.insert("client_id".to_string(), "test_client".to_string()); @@ -98,14 +104,14 @@ mod disabled_tests { ); params.insert("response_type".to_string(), "token".to_string()); - let result = oauth_server.handle_authorize(¶ms); + let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("unsupported_response_type")); } #[test] fn test_jwks_endpoint() { - let config = sts::Config::from_env(); + let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let jwks_json = oauth_server.get_jwks(); @@ -127,7 +133,7 @@ mod disabled_tests { #[test] fn test_full_oauth_flow_with_rsa_tokens() { - let config = sts::Config::from_env(); + let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); // Step 1: Authorization request @@ -137,7 +143,7 @@ mod disabled_tests { auth_params.insert("response_type".to_string(), "code".to_string()); auth_params.insert("state".to_string(), "test_state".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed"); + let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); // Extract the authorization code from redirect URL let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); @@ -154,7 +160,7 @@ mod disabled_tests { token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "test_secret".to_string()); - let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed"); + let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed"); // Parse token response let token_response: serde_json::Value = serde_json::from_str(&token_result) @@ -175,7 +181,7 @@ mod disabled_tests { #[test] fn test_token_validation_with_jwks() { - let config = sts::Config::from_env(); + let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); // Generate a token @@ -184,7 +190,7 @@ mod disabled_tests { auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); auth_params.insert("response_type".to_string(), "code".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed"); + let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -198,7 +204,7 @@ mod disabled_tests { token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "test_secret".to_string()); - let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed"); + let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed"); let token_response: serde_json::Value = serde_json::from_str(&token_result) .expect("Invalid token response JSON"); let access_token = token_response["access_token"].as_str().unwrap(); @@ -225,7 +231,7 @@ mod disabled_tests { #[test] fn test_token_contains_proper_claims() { - let config = sts::Config::from_env(); + let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); // Generate a token through the full flow @@ -235,7 +241,7 @@ mod disabled_tests { auth_params.insert("response_type".to_string(), "code".to_string()); auth_params.insert("scope".to_string(), "openid profile".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed"); + let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -249,7 +255,7 @@ mod disabled_tests { token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "test_secret".to_string()); - let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed"); + let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed"); let token_response: serde_json::Value = serde_json::from_str(&token_result) .expect("Invalid token response JSON"); let access_token = token_response["access_token"].as_str().unwrap(); @@ -282,7 +288,7 @@ mod disabled_tests { #[test] fn test_invalid_client_id_authorization() { - let config = sts::Config::from_env(); + let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let mut params = HashMap::new(); @@ -290,14 +296,14 @@ mod disabled_tests { params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); params.insert("response_type".to_string(), "code".to_string()); - let result = oauth_server.handle_authorize(¶ms); + let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("invalid_client")); } #[test] fn test_invalid_redirect_uri() { - let config = sts::Config::from_env(); + let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let mut params = HashMap::new(); @@ -305,14 +311,14 @@ mod disabled_tests { params.insert("redirect_uri".to_string(), "https://evil.com/callback".to_string()); params.insert("response_type".to_string(), "code".to_string()); - let result = oauth_server.handle_authorize(¶ms); + let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("Invalid redirect_uri")); } #[test] fn test_invalid_scope() { - let config = sts::Config::from_env(); + let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); let mut params = HashMap::new(); @@ -321,14 +327,14 @@ mod disabled_tests { params.insert("response_type".to_string(), "code".to_string()); params.insert("scope".to_string(), "invalid_scope".to_string()); - let result = oauth_server.handle_authorize(¶ms); + let result = oauth_server.handle_authorize(¶ms, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("invalid_scope")); } #[test] fn test_invalid_client_secret() { - let config = sts::Config::from_env(); + let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); // First get an authorization code @@ -337,7 +343,7 @@ mod disabled_tests { auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); auth_params.insert("response_type".to_string(), "code".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed"); + let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -352,14 +358,14 @@ mod disabled_tests { token_params.insert("client_id".to_string(), "test_client".to_string()); token_params.insert("client_secret".to_string(), "wrong_secret".to_string()); - let result = oauth_server.handle_token(&token_params, None); + let result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("invalid_client")); } #[test] fn test_missing_client_secret() { - let config = sts::Config::from_env(); + let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); // First get an authorization code @@ -368,7 +374,7 @@ mod disabled_tests { auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); auth_params.insert("response_type".to_string(), "code".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed"); + let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -383,14 +389,14 @@ mod disabled_tests { token_params.insert("client_id".to_string(), "test_client".to_string()); // Missing client_secret - let result = oauth_server.handle_token(&token_params, None); + let result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())); assert!(result.is_err()); assert!(result.unwrap_err().contains("Missing client_secret")); } #[test] fn test_http_basic_auth() { - let config = sts::Config::from_env(); + let config = setup_test_environment(); let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server"); // First get an authorization code @@ -399,7 +405,7 @@ mod disabled_tests { auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string()); auth_params.insert("response_type".to_string(), "code".to_string()); - let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed"); + let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed"); let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL"); let auth_code = redirect_url .query_pairs() @@ -416,8 +422,7 @@ mod disabled_tests { // test_client:test_secret encoded in base64 let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ="; - let result = oauth_server.handle_token(&token_params, Some(auth_header)); + let result = oauth_server.handle_token(&token_params, Some(auth_header), Some("127.0.0.1".to_string())); assert!(result.is_ok()); } } -*/ diff --git a/src/migrations.rs b/src/migrations.rs index 5076a9e..61c5b19 100644 --- a/src/migrations.rs +++ b/src/migrations.rs @@ -62,12 +62,28 @@ impl<'a> MigrationRunner<'a> { } fn get_current_version(&self) -> Result<i32> { - let version = self.conn.query_row( - "SELECT COALESCE(MAX(version), 0) FROM schema_migrations", + // Check if schema_migrations table exists first + let table_exists = self.conn.query_row( + "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'", [], - |row| row.get::<_, i32>(0), - )?; - Ok(version) + |_| Ok(()), + ); + + match table_exists { + Ok(_) => { + // Table exists, get the current version + let version = self.conn.query_row( + "SELECT COALESCE(MAX(version), 0) FROM schema_migrations", + [], + |row| row.get::<_, i32>(0), + )?; + Ok(version) + } + Err(_) => { + // Table doesn't exist, we're at version 0 + Ok(0) + } + } } fn run_migration(&self, migration: &Migration) -> Result<()> { |
