summaryrefslogtreecommitdiff
path: root/src/oauth/pkce.rs
diff options
context:
space:
mode:
authormo khan <mo@mokhan.ca>2025-06-11 15:12:59 -0600
committermo khan <mo@mokhan.ca>2025-06-11 15:12:59 -0600
commit4435ee26b79648e92d0f172e42f9e6629e955505 (patch)
tree0720fd07c879a58672fcfcb2e45ed1161430f039 /src/oauth/pkce.rs
parent39c67cfc6c74bf4b26ba455f3adda1241aea35ea (diff)
chore: rustfmt and include Connection: header in responses
Diffstat (limited to 'src/oauth/pkce.rs')
-rw-r--r--src/oauth/pkce.rs57
1 files changed, 37 insertions, 20 deletions
diff --git a/src/oauth/pkce.rs b/src/oauth/pkce.rs
index c943844..406d364 100644
--- a/src/oauth/pkce.rs
+++ b/src/oauth/pkce.rs
@@ -1,6 +1,6 @@
-use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
-use sha2::{Digest, Sha256};
use anyhow::{anyhow, Result};
+use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
+use sha2::{Digest, Sha256};
#[derive(Debug, Clone, PartialEq)]
pub enum CodeChallengeMethod {
@@ -32,13 +32,16 @@ pub fn verify_code_challenge(
) -> Result<bool> {
// Validate code verifier format (RFC 7636 Section 4.1)
if code_verifier.len() < 43 || code_verifier.len() > 128 {
- return Err(anyhow!("Code verifier length must be between 43 and 128 characters"));
+ return Err(anyhow!(
+ "Code verifier length must be between 43 and 128 characters"
+ ));
}
// Code verifier must only contain unreserved characters
- if !code_verifier.chars().all(|c| {
- c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~'
- }) {
+ if !code_verifier
+ .chars()
+ .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~')
+ {
return Err(anyhow!("Code verifier contains invalid characters"));
}
@@ -57,7 +60,7 @@ pub fn verify_code_challenge(
pub fn generate_code_verifier() -> String {
use rand::Rng;
let mut rng = rand::thread_rng();
-
+
// Generate 32 random bytes and encode them
let bytes: Vec<u8> = (0..32).map(|_| rng.r#gen()).collect();
URL_SAFE_NO_PAD.encode(&bytes)
@@ -80,8 +83,14 @@ mod tests {
#[test]
fn test_code_challenge_method_from_str() {
- assert_eq!(CodeChallengeMethod::from_str("plain").unwrap(), CodeChallengeMethod::Plain);
- assert_eq!(CodeChallengeMethod::from_str("S256").unwrap(), CodeChallengeMethod::S256);
+ assert_eq!(
+ CodeChallengeMethod::from_str("plain").unwrap(),
+ CodeChallengeMethod::Plain
+ );
+ assert_eq!(
+ CodeChallengeMethod::from_str("S256").unwrap(),
+ CodeChallengeMethod::S256
+ );
assert!(CodeChallengeMethod::from_str("invalid").is_err());
}
@@ -95,7 +104,7 @@ mod tests {
fn test_verify_code_challenge_plain() {
let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
let challenge = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
-
+
assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::Plain).unwrap());
assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::Plain).unwrap());
}
@@ -104,7 +113,7 @@ mod tests {
fn test_verify_code_challenge_s256() {
let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
let challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
-
+
assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::S256).unwrap());
assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::S256).unwrap());
}
@@ -113,9 +122,14 @@ mod tests {
fn test_verify_code_challenge_invalid_verifier() {
// Too short
assert!(verify_code_challenge("short", "challenge", &CodeChallengeMethod::Plain).is_err());
-
+
// Invalid characters
- assert!(verify_code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", "challenge", &CodeChallengeMethod::Plain).is_err());
+ assert!(verify_code_challenge(
+ "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!",
+ "challenge",
+ &CodeChallengeMethod::Plain
+ )
+ .is_err());
}
#[test]
@@ -123,7 +137,7 @@ mod tests {
let verifier = generate_code_verifier();
assert!(verifier.len() >= 43);
assert!(verifier.len() <= 128);
-
+
// Should only contain valid characters
assert!(verifier.chars().all(|c| {
c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~'
@@ -133,24 +147,27 @@ mod tests {
#[test]
fn test_generate_code_challenge() {
let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
-
+
let plain_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::Plain);
assert_eq!(plain_challenge, verifier);
-
+
let s256_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::S256);
- assert_eq!(s256_challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM");
+ assert_eq!(
+ s256_challenge,
+ "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
+ );
}
#[test]
fn test_round_trip() {
let verifier = generate_code_verifier();
-
+
// Test with S256
let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::S256);
assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::S256).unwrap());
-
+
// Test with Plain
let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::Plain);
assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::Plain).unwrap());
}
-} \ No newline at end of file
+}