summaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs134
1 files changed, 100 insertions, 34 deletions
diff --git a/src/main.rs b/src/main.rs
index ac47a5e..4873a1d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,22 +1,36 @@
+use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
-use sts::Config;
+use sts::container::ServiceContainer;
use sts::http::Server;
+use sts::{Config, Database};
fn main() {
let config = Config::from_env();
- let server = Server::new(config.clone()).expect("Failed to create server");
+
+ // Initialize database
+ let database = Database::new(&config.database_path).expect("Failed to initialize database");
+ let database = Arc::new(Mutex::new(database));
+
+ // Initialize service container with dependency injection
+ let container = ServiceContainer::new(config.clone(), database.clone())
+ .expect("Failed to create service container");
+ let container = Arc::new(container);
+
+ let server = Server::new_with_container(config.clone(), container.clone())
+ .expect("Failed to create server");
// Start cleanup task in background
+ let cleanup_container = container.clone();
let cleanup_config = config.clone();
thread::spawn(move || {
loop {
thread::sleep(Duration::from_secs(
cleanup_config.cleanup_interval_hours as u64 * 3600,
));
- // Note: In the current implementation, we don't have direct access to the OAuth server
- // from here to call cleanup_expired_data(). In a production implementation,
- // you'd want to structure this differently or use a background job queue.
+ if let Err(e) = cleanup_container.cleanup_expired_data() {
+ eprintln!("Cleanup task failed: {}", e);
+ }
}
});
@@ -139,11 +153,16 @@ mod tests {
// Step 1: Authorization request
let mut auth_params = HashMap::new();
auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
auth_params.insert("response_type".to_string(), "code".to_string());
auth_params.insert("state".to_string(), "test_state".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
+ let auth_result = oauth_server
+ .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
+ .expect("Authorization failed");
// Extract the authorization code from redirect URL
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
@@ -160,11 +179,13 @@ mod tests {
token_params.insert("client_id".to_string(), "test_client".to_string());
token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed");
+ let token_result = oauth_server
+ .handle_token(&token_params, None, Some("127.0.0.1".to_string()))
+ .expect("Token request failed");
// Parse token response
- let token_response: serde_json::Value = serde_json::from_str(&token_result)
- .expect("Invalid token response JSON");
+ let token_response: serde_json::Value =
+ serde_json::from_str(&token_result).expect("Invalid token response JSON");
assert_eq!(token_response["token_type"], "Bearer");
assert_eq!(token_response["expires_in"], 3600);
@@ -173,7 +194,8 @@ mod tests {
let access_token = token_response["access_token"].as_str().unwrap();
// Step 3: Verify the JWT token has RSA signature and key ID
- let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
+ let header =
+ jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
assert_eq!(header.alg, jsonwebtoken::Algorithm::RS256);
assert!(header.kid.is_some());
assert!(!header.kid.as_ref().unwrap().is_empty());
@@ -187,10 +209,15 @@ mod tests {
// Generate a token
let mut auth_params = HashMap::new();
auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
auth_params.insert("response_type".to_string(), "code".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
+ let auth_result = oauth_server
+ .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
+ .expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -204,9 +231,11 @@ mod tests {
token_params.insert("client_id".to_string(), "test_client".to_string());
token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed");
- let token_response: serde_json::Value = serde_json::from_str(&token_result)
- .expect("Invalid token response JSON");
+ let token_result = oauth_server
+ .handle_token(&token_params, None, Some("127.0.0.1".to_string()))
+ .expect("Token request failed");
+ let token_response: serde_json::Value =
+ serde_json::from_str(&token_result).expect("Invalid token response JSON");
let access_token = token_response["access_token"].as_str().unwrap();
// Get the JWKS
@@ -214,7 +243,8 @@ mod tests {
let jwks: serde_json::Value = serde_json::from_str(&jwks_json).expect("Invalid JWKS JSON");
// Decode the token header to get the key ID
- let header = jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
+ let header =
+ jsonwebtoken::decode_header(access_token).expect("Failed to decode JWT header");
let kid = header.kid.as_ref().expect("No key ID in token");
// Find the matching key in JWKS
@@ -237,11 +267,16 @@ mod tests {
// Generate a token through the full flow
let mut auth_params = HashMap::new();
auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
auth_params.insert("response_type".to_string(), "code".to_string());
auth_params.insert("scope".to_string(), "openid profile".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
+ let auth_result = oauth_server
+ .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
+ .expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -255,16 +290,18 @@ mod tests {
token_params.insert("client_id".to_string(), "test_client".to_string());
token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params, None, Some("127.0.0.1".to_string())).expect("Token request failed");
- let token_response: serde_json::Value = serde_json::from_str(&token_result)
- .expect("Invalid token response JSON");
+ let token_result = oauth_server
+ .handle_token(&token_params, None, Some("127.0.0.1".to_string()))
+ .expect("Token request failed");
+ let token_response: serde_json::Value =
+ serde_json::from_str(&token_result).expect("Invalid token response JSON");
let access_token = token_response["access_token"].as_str().unwrap();
// Decode the token without verification to check claims
let _token_data = jsonwebtoken::decode::<serde_json::Value>(
access_token,
&jsonwebtoken::DecodingKey::from_secret(b"dummy"), // We're not validating, just parsing
- &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256)
+ &jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256),
);
// Since we can't validate with a dummy key, we'll just verify the structure
@@ -275,7 +312,8 @@ mod tests {
let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(parts[1])
.expect("Failed to decode payload");
- let claims: serde_json::Value = serde_json::from_slice(&payload).expect("Invalid claims JSON");
+ let claims: serde_json::Value =
+ serde_json::from_slice(&payload).expect("Invalid claims JSON");
assert!(claims["sub"].is_string());
assert!(claims["iss"].is_string());
@@ -293,7 +331,10 @@ mod tests {
let mut params = HashMap::new();
params.insert("client_id".to_string(), "invalid_client".to_string());
- params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
params.insert("response_type".to_string(), "code".to_string());
let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
@@ -308,7 +349,10 @@ mod tests {
let mut params = HashMap::new();
params.insert("client_id".to_string(), "test_client".to_string());
- params.insert("redirect_uri".to_string(), "https://evil.com/callback".to_string());
+ params.insert(
+ "redirect_uri".to_string(),
+ "https://evil.com/callback".to_string(),
+ );
params.insert("response_type".to_string(), "code".to_string());
let result = oauth_server.handle_authorize(&params, Some("127.0.0.1".to_string()));
@@ -323,7 +367,10 @@ mod tests {
let mut params = HashMap::new();
params.insert("client_id".to_string(), "test_client".to_string());
- params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
params.insert("response_type".to_string(), "code".to_string());
params.insert("scope".to_string(), "invalid_scope".to_string());
@@ -340,10 +387,15 @@ mod tests {
// First get an authorization code
let mut auth_params = HashMap::new();
auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
auth_params.insert("response_type".to_string(), "code".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
+ let auth_result = oauth_server
+ .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
+ .expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -371,10 +423,15 @@ mod tests {
// First get an authorization code
let mut auth_params = HashMap::new();
auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
auth_params.insert("response_type".to_string(), "code".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
+ let auth_result = oauth_server
+ .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
+ .expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -402,10 +459,15 @@ mod tests {
// First get an authorization code
let mut auth_params = HashMap::new();
auth_params.insert("client_id".to_string(), "test_client".to_string());
- auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert(
+ "redirect_uri".to_string(),
+ "http://localhost:3000/callback".to_string(),
+ );
auth_params.insert("response_type".to_string(), "code".to_string());
- let auth_result = oauth_server.handle_authorize(&auth_params, Some("127.0.0.1".to_string())).expect("Authorization failed");
+ let auth_result = oauth_server
+ .handle_authorize(&auth_params, Some("127.0.0.1".to_string()))
+ .expect("Authorization failed");
let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
let auth_code = redirect_url
.query_pairs()
@@ -422,7 +484,11 @@ mod tests {
// test_client:test_secret encoded in base64
let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ=";
- let result = oauth_server.handle_token(&token_params, Some(auth_header), Some("127.0.0.1".to_string()));
+ let result = oauth_server.handle_token(
+ &token_params,
+ Some(auth_header),
+ Some("127.0.0.1".to_string()),
+ );
assert!(result.is_ok());
}
}