summaryrefslogtreecommitdiff
path: root/src/oauth/pkce.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/oauth/pkce.rs')
-rw-r--r--src/oauth/pkce.rs156
1 files changed, 156 insertions, 0 deletions
diff --git a/src/oauth/pkce.rs b/src/oauth/pkce.rs
new file mode 100644
index 0000000..c943844
--- /dev/null
+++ b/src/oauth/pkce.rs
@@ -0,0 +1,156 @@
+use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
+use sha2::{Digest, Sha256};
+use anyhow::{anyhow, Result};
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum CodeChallengeMethod {
+ Plain,
+ S256,
+}
+
+impl CodeChallengeMethod {
+ pub fn from_str(s: &str) -> Result<Self> {
+ match s {
+ "plain" => Ok(CodeChallengeMethod::Plain),
+ "S256" => Ok(CodeChallengeMethod::S256),
+ _ => Err(anyhow!("Unsupported code challenge method: {}", s)),
+ }
+ }
+
+ pub fn as_str(&self) -> &'static str {
+ match self {
+ CodeChallengeMethod::Plain => "plain",
+ CodeChallengeMethod::S256 => "S256",
+ }
+ }
+}
+
+pub fn verify_code_challenge(
+ code_verifier: &str,
+ code_challenge: &str,
+ method: &CodeChallengeMethod,
+) -> Result<bool> {
+ // Validate code verifier format (RFC 7636 Section 4.1)
+ if code_verifier.len() < 43 || code_verifier.len() > 128 {
+ return Err(anyhow!("Code verifier length must be between 43 and 128 characters"));
+ }
+
+ // Code verifier must only contain unreserved characters
+ if !code_verifier.chars().all(|c| {
+ c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~'
+ }) {
+ return Err(anyhow!("Code verifier contains invalid characters"));
+ }
+
+ let computed_challenge = match method {
+ CodeChallengeMethod::Plain => code_verifier.to_string(),
+ CodeChallengeMethod::S256 => {
+ let mut hasher = Sha256::new();
+ hasher.update(code_verifier.as_bytes());
+ URL_SAFE_NO_PAD.encode(hasher.finalize())
+ }
+ };
+
+ Ok(computed_challenge == code_challenge)
+}
+
+pub fn generate_code_verifier() -> String {
+ use rand::Rng;
+ let mut rng = rand::thread_rng();
+
+ // Generate 32 random bytes and encode them
+ let bytes: Vec<u8> = (0..32).map(|_| rng.r#gen()).collect();
+ URL_SAFE_NO_PAD.encode(&bytes)
+}
+
+pub fn generate_code_challenge(verifier: &str, method: &CodeChallengeMethod) -> String {
+ match method {
+ CodeChallengeMethod::Plain => verifier.to_string(),
+ CodeChallengeMethod::S256 => {
+ let mut hasher = Sha256::new();
+ hasher.update(verifier.as_bytes());
+ URL_SAFE_NO_PAD.encode(hasher.finalize())
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_code_challenge_method_from_str() {
+ assert_eq!(CodeChallengeMethod::from_str("plain").unwrap(), CodeChallengeMethod::Plain);
+ assert_eq!(CodeChallengeMethod::from_str("S256").unwrap(), CodeChallengeMethod::S256);
+ assert!(CodeChallengeMethod::from_str("invalid").is_err());
+ }
+
+ #[test]
+ fn test_code_challenge_method_as_str() {
+ assert_eq!(CodeChallengeMethod::Plain.as_str(), "plain");
+ assert_eq!(CodeChallengeMethod::S256.as_str(), "S256");
+ }
+
+ #[test]
+ fn test_verify_code_challenge_plain() {
+ let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
+ let challenge = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
+
+ assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::Plain).unwrap());
+ assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::Plain).unwrap());
+ }
+
+ #[test]
+ fn test_verify_code_challenge_s256() {
+ let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
+ let challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
+
+ assert!(verify_code_challenge(verifier, challenge, &CodeChallengeMethod::S256).unwrap());
+ assert!(!verify_code_challenge(verifier, "wrong", &CodeChallengeMethod::S256).unwrap());
+ }
+
+ #[test]
+ fn test_verify_code_challenge_invalid_verifier() {
+ // Too short
+ assert!(verify_code_challenge("short", "challenge", &CodeChallengeMethod::Plain).is_err());
+
+ // Invalid characters
+ assert!(verify_code_challenge("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjX!", "challenge", &CodeChallengeMethod::Plain).is_err());
+ }
+
+ #[test]
+ fn test_generate_code_verifier() {
+ let verifier = generate_code_verifier();
+ assert!(verifier.len() >= 43);
+ assert!(verifier.len() <= 128);
+
+ // Should only contain valid characters
+ assert!(verifier.chars().all(|c| {
+ c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == '_' || c == '~'
+ }));
+ }
+
+ #[test]
+ fn test_generate_code_challenge() {
+ let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
+
+ let plain_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::Plain);
+ assert_eq!(plain_challenge, verifier);
+
+ let s256_challenge = generate_code_challenge(verifier, &CodeChallengeMethod::S256);
+ assert_eq!(s256_challenge, "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM");
+ }
+
+ #[test]
+ fn test_round_trip() {
+ let verifier = generate_code_verifier();
+
+ // Test with S256
+ let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::S256);
+ assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::S256).unwrap());
+
+ // Test with Plain
+ let challenge = generate_code_challenge(&verifier, &CodeChallengeMethod::Plain);
+ assert!(verify_code_challenge(&verifier, &challenge, &CodeChallengeMethod::Plain).unwrap());
+ }
+} \ No newline at end of file