summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-11 15:48:45 -0600
committermo khan <mo@mokhan.ca>2025-06-11 15:48:45 -0600
commitdbd3c780f27bd5bee23adf6e280b84d669230e0d (patch)
tree21969009f17f8d58e35cf9a73a3ed77eb0e3faca
parentaea6bd6ec7d7e70a67723edf6327df4a9cc65d89 (diff)
test: fix commented out tests
-rw-r--r--src/main.rs79
-rw-r--r--src/migrations.rs26
2 files changed, 63 insertions, 42 deletions
diff --git a/src/main.rs b/src/main.rs
index 0612bfa..ac47a5e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -34,24 +34,30 @@ fn main() {
server.start();
}
-/*
#[cfg(test)]
-mod disabled_tests {
- use std::collections::HashMap;
+mod tests {
use base64::Engine;
+ use std::collections::HashMap;
- #[test]
- fn test_oauth_server_creation() {
+ fn setup_test_environment() -> sts::Config {
let mut config = sts::Config::from_env();
+ config.database_path = ":memory:".to_string(); // Use in-memory database for tests
config.bind_addr = "127.0.0.1:0".to_string();
config.issuer_url = format!("http://{}", config.bind_addr);
+
+ config
+ }
+
+ #[test]
+ fn test_oauth_server_creation() {
+ let config = setup_test_environment();
let server = sts::http::Server::new(config);
assert!(server.is_ok());
}
#[test]
fn test_authorization_code_generation() {
- let config = sts::Config::from_env();
+ let config = setup_test_environment();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
let mut params = HashMap::new();
params.insert("client_id".to_string(), "test_client".to_string());
@@ -62,7 +68,7 @@ mod disabled_tests {
params.insert("response_type".to_string(), "code".to_string());
params.insert("state".to_string(), "test_state".to_string());
- let result = oauth_server.handle_authorize(&params);
+ let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
assert!(result.is_ok());
let redirect_url = result.unwrap();
@@ -72,7 +78,7 @@ mod disabled_tests {
#[test]
fn test_missing_client_id() {
- let config = sts::Config::from_env();
+ let config = setup_test_environment();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
let mut params = HashMap::new();
params.insert(
@@ -81,14 +87,14 @@ mod disabled_tests {
);
params.insert("response_type".to_string(), "code".to_string());
- let result = oauth_server.handle_authorize(&params);
+ let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
assert!(result.is_err());
assert!(result.unwrap_err().contains("invalid_request"));
}
#[test]
fn test_unsupported_response_type() {
- let config = sts::Config::from_env();
+ let config = setup_test_environment();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
let mut params = HashMap::new();
params.insert("client_id".to_string(), "test_client".to_string());
@@ -98,14 +104,14 @@ mod disabled_tests {
);
params.insert("response_type".to_string(), "token".to_string());
- let result = oauth_server.handle_authorize(&params);
+ let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
assert!(result.is_err());
assert!(result.unwrap_err().contains("unsupported_response_type"));
}
#[test]
fn test_jwks_endpoint() {
- let config = sts::Config::from_env();
+ let config = setup_test_environment();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
let jwks_json = oauth_server.get_jwks();
@@ -127,7 +133,7 @@ mod disabled_tests {
#[test]
fn test_full_oauth_flow_with_rsa_tokens() {
- let config = sts::Config::from_env();
+ let config = setup_test_environment();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
// Step 1: Authorization request
@@ -137,7 +143,7 @@ mod disabled_tests {
auth_params.insert("response_type".to_string(), "code".to_string());
auth_params.insert("state".to_string(), "test_state".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
// Extract the authorization code from redirect URL
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
@@ -154,7 +160,7 @@ mod disabled_tests {
token_params.insert("client_id".to_string(), "test_client".to_string());
token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed");
+ let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed");
// Parse token response
let token_response: serde_json::Value = serde_json::from_str(&token_result)
@@ -175,7 +181,7 @@ mod disabled_tests {
#[test]
fn test_token_validation_with_jwks() {
- let config = sts::Config::from_env();
+ let config = setup_test_environment();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
// Generate a token
@@ -184,7 +190,7 @@ mod disabled_tests {
auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
auth_params.insert("response_type".to_string(), "code".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -198,7 +204,7 @@ mod disabled_tests {
token_params.insert("client_id".to_string(), "test_client".to_string());
token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed");
+ let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed");
let token_response: serde_json::Value = serde_json::from_str(&token_result)
.expect("Invalid token response JSON");
let access_token = token_response["access_token"].as_str().unwrap();
@@ -225,7 +231,7 @@ mod disabled_tests {
#[test]
fn test_token_contains_proper_claims() {
- let config = sts::Config::from_env();
+ let config = setup_test_environment();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
// Generate a token through the full flow
@@ -235,7 +241,7 @@ mod disabled_tests {
auth_params.insert("response_type".to_string(), "code".to_string());
auth_params.insert("scope".to_string(), "openid profile".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -249,7 +255,7 @@ mod disabled_tests {
token_params.insert("client_id".to_string(), "test_client".to_string());
token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed");
+ let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed");
let token_response: serde_json::Value = serde_json::from_str(&token_result)
.expect("Invalid token response JSON");
let access_token = token_response["access_token"].as_str().unwrap();
@@ -282,7 +288,7 @@ mod disabled_tests {
#[test]
fn test_invalid_client_id_authorization() {
- let config = sts::Config::from_env();
+ let config = setup_test_environment();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
let mut params = HashMap::new();
@@ -290,14 +296,14 @@ mod disabled_tests {
params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
params.insert("response_type".to_string(), "code".to_string());
- let result = oauth_server.handle_authorize(&params);
+ let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
assert!(result.is_err());
assert!(result.unwrap_err().contains("invalid_client"));
}
#[test]
fn test_invalid_redirect_uri() {
- let config = sts::Config::from_env();
+ let config = setup_test_environment();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
let mut params = HashMap::new();
@@ -305,14 +311,14 @@ mod disabled_tests {
params.insert("redirect_uri".to_string(), "https://evil.com/callback".to_string());
params.insert("response_type".to_string(), "code".to_string());
- let result = oauth_server.handle_authorize(&params);
+ let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
assert!(result.is_err());
assert!(result.unwrap_err().contains("Invalid redirect_uri"));
}
#[test]
fn test_invalid_scope() {
- let config = sts::Config::from_env();
+ let config = setup_test_environment();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
let mut params = HashMap::new();
@@ -321,14 +327,14 @@ mod disabled_tests {
params.insert("response_type".to_string(), "code".to_string());
params.insert("scope".to_string(), "invalid_scope".to_string());
- let result = oauth_server.handle_authorize(&params);
+ let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
assert!(result.is_err());
assert!(result.unwrap_err().contains("invalid_scope"));
}
#[test]
fn test_invalid_client_secret() {
- let config = sts::Config::from_env();
+ let config = setup_test_environment();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
// First get an authorization code
@@ -337,7 +343,7 @@ mod disabled_tests {
auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
auth_params.insert("response_type".to_string(), "code".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -352,14 +358,14 @@ mod disabled_tests {
token_params.insert("client_id".to_string(), "test_client".to_string());
token_params.insert("client_secret".to_string(), "wrong_secret".to_string());
- let result = oauth_server.handle_token(&token_params, None);
+ let result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string()));
assert!(result.is_err());
assert!(result.unwrap_err().contains("invalid_client"));
}
#[test]
fn test_missing_client_secret() {
- let config = sts::Config::from_env();
+ let config = setup_test_environment();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
// First get an authorization code
@@ -368,7 +374,7 @@ mod disabled_tests {
auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
auth_params.insert("response_type".to_string(), "code".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -383,14 +389,14 @@ mod disabled_tests {
token_params.insert("client_id".to_string(), "test_client".to_string());
// Missing client_secret
- let result = oauth_server.handle_token(&token_params, None);
+ let result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string()));
assert!(result.is_err());
assert!(result.unwrap_err().contains("Missing client_secret"));
}
#[test]
fn test_http_basic_auth() {
- let config = sts::Config::from_env();
+ let config = setup_test_environment();
let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
// First get an authorization code
@@ -399,7 +405,7 @@ mod disabled_tests {
auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
auth_params.insert("response_type".to_string(), "code".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -416,8 +422,7 @@ mod disabled_tests {
// test_client:test_secret encoded in base64
let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ=";
- let result = oauth_server.handle_token(&token_params, Some(auth_header));
+ let result = oauth_server.handle_token(&token_params, Some(auth_header), Some("127.0.0.1".to_string()));
assert!(result.is_ok());
}
}
-*/
diff --git a/src/migrations.rs b/src/migrations.rs
index 5076a9e..61c5b19 100644
--- a/src/migrations.rs
+++ b/src/migrations.rs
@@ -62,12 +62,28 @@ impl<'a> MigrationRunner<'a> {
}
fn get_current_version(&self) -> Result<i32> {
- let version = self.conn.query_row(
- "SELECT COALESCE(MAX(version), 0) FROM schema_migrations",
+ // Check if schema_migrations table exists first
+ let table_exists = self.conn.query_row(
+ "SELECT name FROM sqlite_master WHERE type='table' AND name='schema_migrations'",
[],
- |row| row.get::<_, i32>(0),
- )?;
- Ok(version)
+ |_| Ok(()),
+ );
+
+ match table_exists {
+ Ok(_) => {
+ // Table exists, get the current version
+ let version = self.conn.query_row(
+ "SELECT COALESCE(MAX(version), 0) FROM schema_migrations",
+ [],
+ |row| row.get::<_, i32>(0),
+ )?;
+ Ok(version)
+ }
+ Err(_) => {
+ // Table doesn't exist, we're at version 0
+ Ok(0)
+ }
+ }
}
fn run_migration(&self, migration: &Migration) -> Result<()> {