diff options
Diffstat (limited to 'src/oauth/pkce.rs')
| -rw-r--r-- | src/oauth/pkce.rs | 156 |
1 files changed, 156 insertions, 0 deletions
diff --git a/src/oauth/pkce.rs b/src/oauth/pkce.rs new file mode 100644 index 0000000..c943844 --- /dev/null +++ b/src/oauth/pkce.rs @@ -0,0 +1,156 @@ +use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; +use sha2::{Digest, Sha256}; +use anyhow::{anyhow, Result}; + +#[derive(Debug, Clone, PartialEq)] +pub enum CodeChallengeMethod { + Plain, + S256, +} + +impl CodeChallengeMethod { + pub fn from_str(s: &str) -> Result<Self> { + match s { + "plain" => Ok(CodeChallengeMethod::Plain), + "S256" => Ok(CodeChallengeMethod::S256), + _ => Err(anyhow!("Unsupported code challenge method: {}", s)), + } + } + + pub fn as_str(&self) -> &'static str { + match self { + CodeChallengeMethod::Plain => "plain", + CodeChallengeMethod::S256 => "S256", + } + } +} + +pub fn verify_code_challenge( + code_verifier: &str, + code_challenge: &str, + method: &CodeChallengeMethod, +) -> Result<bool> { + // Validate code verifier format (RFC 7636 Section 4.1) + if code_verifier.len() < 43 || code_verifier.len() > 128 { + return Err(anyhow!("Code verifier length must be between 43 and 128 characters")); + } + + // Code verifier must only contain unreserved characters + if !code_verifier.chars().all(|c| { + c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~' + }) { + return Err(anyhow!("Code verifier contains invalid characters")); + } + + let computed_challenge = match method { + CodeChallengeMethod::Plain => code_verifier.to_string(), + CodeChallengeMethod::S256 => { + let mut hasher = Sha256::new(); + hasher.update(code_verifier.as_bytes()); + URL_SAFE_NO_PAD.encode(hasher.finalize()) + } + }; + + Ok(computed_challenge == code_challenge) +} + +pub fn generate_code_verifier() -> String { + use rand::Rng; + let mut rng = rand::thread_rng(); + + // Generate 32 random bytes and encode them + let bytes: Vec<u8> = (0..32).map(|_| rng.r#gen()).collect(); + URL_SAFE_NO_PAD.encode(&bytes) +} + +pub fn generate_code_challenge(verifier: &str, method: &CodeChallengeMethod) -> String { + match method { + CodeChallengeMethod::Plain => verifier.to_string(), + CodeChallengeMethod::S256 => { + let mut hasher = Sha256::new(); + hasher.update(verifier.as_bytes()); + URL_SAFE_NO_PAD.encode(hasher.finalize()) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_code_challenge_method_from_str() { + assert_eq!(CodeChallengeMethod::from_str("plain").unwrap(), CodeChallengeMethod::Plain); + assert_eq!(CodeChallengeMethod::from_str("S256").unwrap(), CodeChallengeMethod::S256); + assert!(CodeChallengeMethod::from_str("invalid").is_err()); + } + + #[test] + fn test_code_challenge_method_as_str() { + assert_eq!(CodeChallengeMethod::Plain.as_str(), "plain"); + assert_eq!(CodeChallengeMethod::S256.as_str(), "S256"); + } + + #[test] + fn test_verify_code_challenge_plain() { + let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + let challenge = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + + assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::Plain).unwrap()); + assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::Plain).unwrap()); + } + + #[test] + fn test_verify_code_challenge_s256() { + let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + let challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; + + assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::S256).unwrap()); + assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::S256).unwrap()); + } + + #[test] + fn test_verify_code_challenge_invalid_verifier() { + // Too short + assert!(verify_code_challenge("short", "challenge", &CodeChallengeMethod::Plain).is_err()); + + // Invalid characters + assert!(verify_code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", "challenge", &CodeChallengeMethod::Plain).is_err()); + } + + #[test] + fn test_generate_code_verifier() { + let verifier = generate_code_verifier(); + assert!(verifier.len() >= 43); + assert!(verifier.len() <= 128); + + // Should only contain valid characters + assert!(verifier.chars().all(|c| { + c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~' + })); + } + + #[test] + fn test_generate_code_challenge() { + let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; + + let plain_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::Plain); + assert_eq!(plain_challenge, verifier); + + let s256_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::S256); + assert_eq!(s256_challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"); + } + + #[test] + fn test_round_trip() { + let verifier = generate_code_verifier(); + + // Test with S256 + let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::S256); + assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::S256).unwrap()); + + // Test with Plain + let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::Plain); + assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::Plain).unwrap()); + } +}
\ No newline at end of file |
