diff options
| author | mo khan <mo@mokhan.ca> | 2025-06-09 17:20:08 -0600 |
|---|---|---|
| committer | mo khan <mo@mokhan.ca> | 2025-06-09 17:20:08 -0600 |
| commit | 72c2297eda4c18f75e7d8587773b36f3ac98b309 (patch) | |
| tree | 091054758812dbfa14979fabb7212a100f294e55 /src/keys.rs | |
| parent | 2ef774d4c52b9fb0ae0d1717b7a3568b76bccf3d (diff) | |
refactor: replace single shared with key rsa keys
Diffstat (limited to 'src/keys.rs')
| -rw-r--r-- | src/keys.rs | 154 |
1 files changed, 154 insertions, 0 deletions
diff --git a/src/keys.rs b/src/keys.rs new file mode 100644 index 0000000..6c25681 --- /dev/null +++ b/src/keys.rs @@ -0,0 +1,154 @@ +use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; +use jsonwebtoken::{DecodingKey, EncodingKey}; +use rsa::pkcs8::{EncodePrivateKey, EncodePublicKey}; +use rsa::traits::PublicKeyParts; +use rsa::{RsaPrivateKey, RsaPublicKey}; +use serde::Serialize; +use std::collections::HashMap; +use std::time::{SystemTime, UNIX_EPOCH}; +use uuid::Uuid; + +#[derive(Clone)] +pub struct KeyPair { + pub kid: String, + pub private_key: RsaPrivateKey, + pub public_key: RsaPublicKey, + pub created_at: u64, + pub encoding_key: EncodingKey, + pub decoding_key: DecodingKey, +} + +#[derive(Debug, Serialize)] +pub struct JwkKey { + pub kty: String, + pub use_: String, + pub kid: String, + pub alg: String, + pub n: String, + pub e: String, +} + +#[derive(Debug, Serialize)] +pub struct Jwks { + pub keys: Vec<JwkKey>, +} + +pub struct KeyManager { + keys: HashMap<String, KeyPair>, + current_key_id: Option<String>, + key_rotation_interval: u64, // seconds +} + +impl KeyManager { + pub fn new() -> Result<Self, Box<dyn std::error::Error>> { + let mut manager = Self { + keys: HashMap::new(), + current_key_id: None, + key_rotation_interval: 86400, // 24 hours + }; + + manager.generate_new_key()?; + Ok(manager) + } + + pub fn generate_new_key(&mut self) -> Result<String, Box<dyn std::error::Error>> { + let mut rng = rand::thread_rng(); + let private_key = RsaPrivateKey::new(&mut rng, 2048)?; + let public_key = RsaPublicKey::from(&private_key); + + let kid = Uuid::new_v4().to_string(); + let created_at = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs(); + + let encoding_key = EncodingKey::from_rsa_pem( + &private_key + .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)? + .as_bytes(), + )?; + let decoding_key = DecodingKey::from_rsa_pem( + &public_key + .to_public_key_pem(rsa::pkcs8::LineEnding::LF)? + .as_bytes(), + )?; + + let key_pair = KeyPair { + kid: kid.clone(), + private_key, + public_key, + created_at, + encoding_key, + decoding_key, + }; + + self.keys.insert(kid.clone(), key_pair); + self.current_key_id = Some(kid.clone()); + + Ok(kid) + } + + pub fn get_current_key(&self) -> Option<&KeyPair> { + self.current_key_id + .as_ref() + .and_then(|kid| self.keys.get(kid)) + } + + pub fn get_key(&self, kid: &str) -> Option<&KeyPair> { + self.keys.get(kid) + } + + pub fn should_rotate(&self) -> bool { + if let Some(current_key) = self.get_current_key() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + now - current_key.created_at > self.key_rotation_interval + } else { + true + } + } + + pub fn rotate_keys(&mut self) -> Result<(), Box<dyn std::error::Error>> { + self.generate_new_key()?; + Ok(()) + } + + pub fn get_jwks(&self) -> Result<Jwks, Box<dyn std::error::Error>> { + let mut keys = Vec::new(); + + for key_pair in self.keys.values() { + let n = URL_SAFE_NO_PAD.encode(&key_pair.public_key.n().to_bytes_be()); + let e = URL_SAFE_NO_PAD.encode(&key_pair.public_key.e().to_bytes_be()); + + keys.push(JwkKey { + kty: "RSA".to_string(), + use_: "sig".to_string(), + kid: key_pair.kid.clone(), + alg: "RS256".to_string(), + n, + e, + }); + } + + Ok(Jwks { keys }) + } + + pub fn cleanup_old_keys(&mut self, max_age: u64) { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + let current_kid = self.current_key_id.clone(); + + self.keys.retain(|kid, key_pair| { + // Always keep the current key + if Some(kid) == current_kid.as_ref() { + return true; + } + + // Keep keys that are not too old + now - key_pair.created_at <= max_age + }); + } +} |
