use anyhow::{Result, anyhow}; use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; use sha2::{Digest, Sha256}; #[derive(Debug, Clone, PartialEq)] pub enum CodeChallengeMethod { Plain, S256, } impl CodeChallengeMethod { pub fn from_str(s: &str) -> Result { match s { "plain" => Ok(CodeChallengeMethod::Plain), "S256" => Ok(CodeChallengeMethod::S256), _ => Err(anyhow!("Unsupported code challenge method: {}", s)), } } pub fn as_str(&self) -> &'static str { match self { CodeChallengeMethod::Plain => "plain", CodeChallengeMethod::S256 => "S256", } } } pub fn verify_code_challenge( code_verifier: &str, code_challenge: &str, method: &CodeChallengeMethod, ) -> Result { // Validate code verifier format (RFC 7636 Section 4.1) if code_verifier.len() < 43 || code_verifier.len() > 128 { return Err(anyhow!( "Code verifier length must be between 43 and 128 characters" )); } // Code verifier must only contain unreserved characters if !code_verifier .chars() .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~') { return Err(anyhow!("Code verifier contains invalid characters")); } let computed_challenge = match method { CodeChallengeMethod::Plain => code_verifier.to_string(), CodeChallengeMethod::S256 => { let mut hasher = Sha256::new(); hasher.update(code_verifier.as_bytes()); URL_SAFE_NO_PAD.encode(hasher.finalize()) } }; Ok(computed_challenge == code_challenge) } pub fn generate_code_verifier() -> String { use rand::Rng; let mut rng = rand::thread_rng(); // Generate 32 random bytes and encode them let bytes: Vec = (0..32).map(|_| rng.r#gen()).collect(); URL_SAFE_NO_PAD.encode(&bytes) } pub fn generate_code_challenge(verifier: &str, method: &CodeChallengeMethod) -> String { match method { CodeChallengeMethod::Plain => verifier.to_string(), CodeChallengeMethod::S256 => { let mut hasher = Sha256::new(); hasher.update(verifier.as_bytes()); URL_SAFE_NO_PAD.encode(hasher.finalize()) } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_code_challenge_method_from_str() { assert_eq!( CodeChallengeMethod::from_str("plain").unwrap(), CodeChallengeMethod::Plain ); assert_eq!( CodeChallengeMethod::from_str("S256").unwrap(), CodeChallengeMethod::S256 ); assert!(CodeChallengeMethod::from_str("invalid").is_err()); } #[test] fn test_code_challenge_method_as_str() { assert_eq!(CodeChallengeMethod::Plain.as_str(), "plain"); assert_eq!(CodeChallengeMethod::S256.as_str(), "S256"); } #[test] fn test_verify_code_challenge_plain() { let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; let challenge = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::Plain).unwrap()); assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::Plain).unwrap()); } #[test] fn test_verify_code_challenge_s256() { let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; let challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"; assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::S256).unwrap()); assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::S256).unwrap()); } #[test] fn test_verify_code_challenge_invalid_verifier() { // Too short assert!(verify_code_challenge("short", "challenge", &CodeChallengeMethod::Plain).is_err()); // Invalid characters assert!( verify_code_challenge( "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", "challenge", &CodeChallengeMethod::Plain ) .is_err() ); } #[test] fn test_generate_code_verifier() { let verifier = generate_code_verifier(); assert!(verifier.len() >= 43); assert!(verifier.len() <= 128); // Should only contain valid characters assert!(verifier.chars().all(|c| { c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~' })); } #[test] fn test_generate_code_challenge() { let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"; let plain_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::Plain); assert_eq!(plain_challenge, verifier); let s256_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::S256); assert_eq!( s256_challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" ); } #[test] fn test_round_trip() { let verifier = generate_code_verifier(); // Test with S256 let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::S256); assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::S256).unwrap()); // Test with Plain let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::Plain); assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::Plain).unwrap()); } }