summaryrefslogtreecommitdiff
path: root/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.rs')
-rw-r--r--src/main.rs149
1 files changed, 146 insertions, 3 deletions
diff --git a/src/main.rs b/src/main.rs
index 9b6b0ea..6b31f92 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -124,8 +124,9 @@ mod tests {
token_params.insert("grant_type".to_string(), "authorization_code".to_string());
token_params.insert("code".to_string(), auth_code);
token_params.insert("client_id".to_string(), "test_client".to_string());
+ token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params).expect("Token request failed");
+ let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed");
// Parse token response
let token_response: serde_json::Value = serde_json::from_str(&token_result)
@@ -167,8 +168,9 @@ mod tests {
token_params.insert("grant_type".to_string(), "authorization_code".to_string());
token_params.insert("code".to_string(), auth_code);
token_params.insert("client_id".to_string(), "test_client".to_string());
+ token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params).expect("Token request failed");
+ let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed");
let token_response: serde_json::Value = serde_json::from_str(&token_result)
.expect("Invalid token response JSON");
let access_token = token_response["access_token"].as_str().unwrap();
@@ -217,8 +219,9 @@ mod tests {
token_params.insert("grant_type".to_string(), "authorization_code".to_string());
token_params.insert("code".to_string(), auth_code);
token_params.insert("client_id".to_string(), "test_client".to_string());
+ token_params.insert("client_secret".to_string(), "test_secret".to_string());
- let token_result = oauth_server.handle_token(&token_params).expect("Token request failed");
+ let token_result = oauth_server.handle_token(&token_params, None).expect("Token request failed");
let token_response: serde_json::Value = serde_json::from_str(&token_result)
.expect("Invalid token response JSON");
let access_token = token_response["access_token"].as_str().unwrap();
@@ -248,4 +251,144 @@ mod tests {
assert_eq!(claims["aud"], "test_client");
assert_eq!(claims["scope"], "openid profile");
}
+
+ #[test]
+ fn test_invalid_client_id_authorization() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ let mut params = HashMap::new();
+ params.insert("client_id".to_string(), "invalid_client".to_string());
+ params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ params.insert("response_type".to_string(), "code".to_string());
+
+ let result = oauth_server.handle_authorize(&params);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().contains("invalid_client"));
+ }
+
+ #[test]
+ fn test_invalid_redirect_uri() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ let mut params = HashMap::new();
+ params.insert("client_id".to_string(), "test_client".to_string());
+ params.insert("redirect_uri".to_string(), "https://evil.com/callback".to_string());
+ params.insert("response_type".to_string(), "code".to_string());
+
+ let result = oauth_server.handle_authorize(&params);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().contains("Invalid redirect_uri"));
+ }
+
+ #[test]
+ fn test_invalid_scope() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ let mut params = HashMap::new();
+ params.insert("client_id".to_string(), "test_client".to_string());
+ params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ params.insert("response_type".to_string(), "code".to_string());
+ params.insert("scope".to_string(), "invalid_scope".to_string());
+
+ let result = oauth_server.handle_authorize(&params);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().contains("invalid_scope"));
+ }
+
+ #[test]
+ fn test_invalid_client_secret() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ // First get an authorization code
+ let mut auth_params = HashMap::new();
+ auth_params.insert("client_id".to_string(), "test_client".to_string());
+ auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert("response_type".to_string(), "code".to_string());
+
+ let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
+ let auth_code = redirect_url
+ .query_pairs()
+ .find(|(key, _)| key == "code")
+ .map(|(_, value)| value.to_string())
+ .expect("No authorization code in redirect");
+
+ // Try to exchange with wrong client secret
+ let mut token_params = HashMap::new();
+ token_params.insert("grant_type".to_string(), "authorization_code".to_string());
+ token_params.insert("code".to_string(), auth_code);
+ token_params.insert("client_id".to_string(), "test_client".to_string());
+ token_params.insert("client_secret".to_string(), "wrong_secret".to_string());
+
+ let result = oauth_server.handle_token(&token_params, None);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().contains("invalid_client"));
+ }
+
+ #[test]
+ fn test_missing_client_secret() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ // First get an authorization code
+ let mut auth_params = HashMap::new();
+ auth_params.insert("client_id".to_string(), "test_client".to_string());
+ auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert("response_type".to_string(), "code".to_string());
+
+ let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
+ let auth_code = redirect_url
+ .query_pairs()
+ .find(|(key, _)| key == "code")
+ .map(|(_, value)| value.to_string())
+ .expect("No authorization code in redirect");
+
+ // Try to exchange without client secret
+ let mut token_params = HashMap::new();
+ token_params.insert("grant_type".to_string(), "authorization_code".to_string());
+ token_params.insert("code".to_string(), auth_code);
+ token_params.insert("client_id".to_string(), "test_client".to_string());
+ // Missing client_secret
+
+ let result = oauth_server.handle_token(&token_params, None);
+ assert!(result.is_err());
+ assert!(result.unwrap_err().contains("Missing client_secret"));
+ }
+
+ #[test]
+ fn test_http_basic_auth() {
+ let config = sts::Config::from_env();
+ let oauth_server = sts::OAuthServer::new(&config).expect("Failed to create OAuth server");
+
+ // First get an authorization code
+ let mut auth_params = HashMap::new();
+ auth_params.insert("client_id".to_string(), "test_client".to_string());
+ auth_params.insert("redirect_uri".to_string(), "http://localhost:3000/callback".to_string());
+ auth_params.insert("response_type".to_string(), "code".to_string());
+
+ let auth_result = oauth_server.handle_authorize(&auth_params).expect("Authorization failed");
+ let redirect_url = url::Url::parse(&auth_result).expect("Invalid redirect URL");
+ let auth_code = redirect_url
+ .query_pairs()
+ .find(|(key, _)| key == "code")
+ .map(|(_, value)| value.to_string())
+ .expect("No authorization code in redirect");
+
+ // Use HTTP Basic Auth instead of form parameters
+ let mut token_params = HashMap::new();
+ token_params.insert("grant_type".to_string(), "authorization_code".to_string());
+ token_params.insert("code".to_string(), auth_code);
+ // No client_id/client_secret in form
+
+ // test_client:test_secret encoded in base64
+ let auth_header = "Basic dGVzdF9jbGllbnQ6dGVzdF9zZWNyZXQ=";
+
+ let result = oauth_server.handle_token(&token_params, Some(auth_header));
+ assert!(result.is_ok());
+ }
}