summaryrefslogtreecommitdiff
path: root/src/keys.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/keys.rs')
-rw-r--r--src/keys.rs154
1 files changed, 154 insertions, 0 deletions
diff --git a/src/keys.rs b/src/keys.rs
new file mode 100644
index 0000000..6c25681
--- /dev/null
+++ b/src/keys.rs
@@ -0,0 +1,154 @@
+use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
+use jsonwebtoken::{DecodingKey, EncodingKey};
+use rsa::pkcs8::{EncodePrivateKey, EncodePublicKey};
+use rsa::traits::PublicKeyParts;
+use rsa::{RsaPrivateKey, RsaPublicKey};
+use serde::Serialize;
+use std::collections::HashMap;
+use std::time::{SystemTime, UNIX_EPOCH};
+use uuid::Uuid;
+
+#[derive(Clone)]
+pub struct KeyPair {
+ pub kid: String,
+ pub private_key: RsaPrivateKey,
+ pub public_key: RsaPublicKey,
+ pub created_at: u64,
+ pub encoding_key: EncodingKey,
+ pub decoding_key: DecodingKey,
+}
+
+#[derive(Debug, Serialize)]
+pub struct JwkKey {
+ pub kty: String,
+ pub use_: String,
+ pub kid: String,
+ pub alg: String,
+ pub n: String,
+ pub e: String,
+}
+
+#[derive(Debug, Serialize)]
+pub struct Jwks {
+ pub keys: Vec<JwkKey>,
+}
+
+pub struct KeyManager {
+ keys: HashMap<String, KeyPair>,
+ current_key_id: Option<String>,
+ key_rotation_interval: u64, // seconds
+}
+
+impl KeyManager {
+ pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
+ let mut manager = Self {
+ keys: HashMap::new(),
+ current_key_id: None,
+ key_rotation_interval: 86400, // 24 hours
+ };
+
+ manager.generate_new_key()?;
+ Ok(manager)
+ }
+
+ pub fn generate_new_key(&mut self) -> Result<String, Box<dyn std::error::Error>> {
+ let mut rng = rand::thread_rng();
+ let private_key = RsaPrivateKey::new(&mut rng, 2048)?;
+ let public_key = RsaPublicKey::from(&private_key);
+
+ let kid = Uuid::new_v4().to_string();
+ let created_at = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
+
+ let encoding_key = EncodingKey::from_rsa_pem(
+ &private_key
+ .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)?
+ .as_bytes(),
+ )?;
+ let decoding_key = DecodingKey::from_rsa_pem(
+ &public_key
+ .to_public_key_pem(rsa::pkcs8::LineEnding::LF)?
+ .as_bytes(),
+ )?;
+
+ let key_pair = KeyPair {
+ kid: kid.clone(),
+ private_key,
+ public_key,
+ created_at,
+ encoding_key,
+ decoding_key,
+ };
+
+ self.keys.insert(kid.clone(), key_pair);
+ self.current_key_id = Some(kid.clone());
+
+ Ok(kid)
+ }
+
+ pub fn get_current_key(&self) -> Option<&KeyPair> {
+ self.current_key_id
+ .as_ref()
+ .and_then(|kid| self.keys.get(kid))
+ }
+
+ pub fn get_key(&self, kid: &str) -> Option<&KeyPair> {
+ self.keys.get(kid)
+ }
+
+ pub fn should_rotate(&self) -> bool {
+ if let Some(current_key) = self.get_current_key() {
+ let now = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ now - current_key.created_at > self.key_rotation_interval
+ } else {
+ true
+ }
+ }
+
+ pub fn rotate_keys(&mut self) -> Result<(), Box<dyn std::error::Error>> {
+ self.generate_new_key()?;
+ Ok(())
+ }
+
+ pub fn get_jwks(&self) -> Result<Jwks, Box<dyn std::error::Error>> {
+ let mut keys = Vec::new();
+
+ for key_pair in self.keys.values() {
+ let n = URL_SAFE_NO_PAD.encode(&key_pair.public_key.n().to_bytes_be());
+ let e = URL_SAFE_NO_PAD.encode(&key_pair.public_key.e().to_bytes_be());
+
+ keys.push(JwkKey {
+ kty: "RSA".to_string(),
+ use_: "sig".to_string(),
+ kid: key_pair.kid.clone(),
+ alg: "RS256".to_string(),
+ n,
+ e,
+ });
+ }
+
+ Ok(Jwks { keys })
+ }
+
+ pub fn cleanup_old_keys(&mut self, max_age: u64) {
+ let now = SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap()
+ .as_secs();
+
+ let current_kid = self.current_key_id.clone();
+
+ self.keys.retain(|kid, key_pair| {
+ // Always keep the current key
+ if Some(kid) == current_kid.as_ref() {
+ return true;
+ }
+
+ // Keep keys that are not too old
+ now - key_pair.created_at <= max_age
+ });
+ }
+}