diff options
Diffstat (limited to 'vendor/rustls/src/record_layer.rs')
| -rw-r--r-- | vendor/rustls/src/record_layer.rs | 336 |
1 files changed, 336 insertions, 0 deletions
diff --git a/vendor/rustls/src/record_layer.rs b/vendor/rustls/src/record_layer.rs new file mode 100644 index 00000000..8c2ea6b9 --- /dev/null +++ b/vendor/rustls/src/record_layer.rs @@ -0,0 +1,336 @@ +use alloc::boxed::Box; +use core::cmp::min; + +use crate::crypto::cipher::{InboundOpaqueMessage, MessageDecrypter, MessageEncrypter}; +use crate::error::Error; +use crate::log::trace; +use crate::msgs::message::{InboundPlainMessage, OutboundOpaqueMessage, OutboundPlainMessage}; + +#[derive(PartialEq)] +enum DirectionState { + /// No keying material. + Invalid, + + /// Keying material present, but not yet in use. + Prepared, + + /// Keying material in use. + Active, +} + +/// Record layer that tracks decryption and encryption keys. +pub(crate) struct RecordLayer { + message_encrypter: Box<dyn MessageEncrypter>, + message_decrypter: Box<dyn MessageDecrypter>, + write_seq_max: u64, + write_seq: u64, + read_seq: u64, + has_decrypted: bool, + encrypt_state: DirectionState, + decrypt_state: DirectionState, + + // Message encrypted with other keys may be encountered, so failures + // should be swallowed by the caller. This struct tracks the amount + // of message size this is allowed for. + trial_decryption_len: Option<usize>, +} + +impl RecordLayer { + /// Create new record layer with no keys. + pub(crate) fn new() -> Self { + Self { + message_encrypter: <dyn MessageEncrypter>::invalid(), + message_decrypter: <dyn MessageDecrypter>::invalid(), + write_seq_max: 0, + write_seq: 0, + read_seq: 0, + has_decrypted: false, + encrypt_state: DirectionState::Invalid, + decrypt_state: DirectionState::Invalid, + trial_decryption_len: None, + } + } + + /// Decrypt a TLS message. + /// + /// `encr` is a decoded message allegedly received from the peer. + /// If it can be decrypted, its decryption is returned. Otherwise, + /// an error is returned. + pub(crate) fn decrypt_incoming<'a>( + &mut self, + encr: InboundOpaqueMessage<'a>, + ) -> Result<Option<Decrypted<'a>>, Error> { + if self.decrypt_state != DirectionState::Active { + return Ok(Some(Decrypted { + want_close_before_decrypt: false, + plaintext: encr.into_plain_message(), + })); + } + + // Set to `true` if the peer appears to getting close to encrypting + // too many messages with this key. + // + // Perhaps if we send an alert well before their counter wraps, a + // buggy peer won't make a terrible mistake here? + // + // Note that there's no reason to refuse to decrypt: the security + // failure has already happened. + let want_close_before_decrypt = self.read_seq == SEQ_SOFT_LIMIT; + + let encrypted_len = encr.payload.len(); + match self + .message_decrypter + .decrypt(encr, self.read_seq) + { + Ok(plaintext) => { + self.read_seq += 1; + if !self.has_decrypted { + self.has_decrypted = true; + } + Ok(Some(Decrypted { + want_close_before_decrypt, + plaintext, + })) + } + Err(Error::DecryptError) if self.doing_trial_decryption(encrypted_len) => { + trace!("Dropping undecryptable message after aborted early_data"); + Ok(None) + } + Err(err) => Err(err), + } + } + + /// Encrypt a TLS message. + /// + /// `plain` is a TLS message we'd like to send. This function + /// panics if the requisite keying material hasn't been established yet. + pub(crate) fn encrypt_outgoing( + &mut self, + plain: OutboundPlainMessage<'_>, + ) -> OutboundOpaqueMessage { + debug_assert!(self.encrypt_state == DirectionState::Active); + assert!(self.next_pre_encrypt_action() != PreEncryptAction::Refuse); + let seq = self.write_seq; + self.write_seq += 1; + self.message_encrypter + .encrypt(plain, seq) + .unwrap() + } + + /// Prepare to use the given `MessageEncrypter` for future message encryption. + /// It is not used until you call `start_encrypting`. + pub(crate) fn prepare_message_encrypter( + &mut self, + cipher: Box<dyn MessageEncrypter>, + max_messages: u64, + ) { + self.message_encrypter = cipher; + self.write_seq = 0; + self.write_seq_max = min(SEQ_SOFT_LIMIT, max_messages); + self.encrypt_state = DirectionState::Prepared; + } + + /// Prepare to use the given `MessageDecrypter` for future message decryption. + /// It is not used until you call `start_decrypting`. + pub(crate) fn prepare_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) { + self.message_decrypter = cipher; + self.read_seq = 0; + self.decrypt_state = DirectionState::Prepared; + } + + /// Start using the `MessageEncrypter` previously provided to the previous + /// call to `prepare_message_encrypter`. + pub(crate) fn start_encrypting(&mut self) { + debug_assert!(self.encrypt_state == DirectionState::Prepared); + self.encrypt_state = DirectionState::Active; + } + + /// Start using the `MessageDecrypter` previously provided to the previous + /// call to `prepare_message_decrypter`. + pub(crate) fn start_decrypting(&mut self) { + debug_assert!(self.decrypt_state == DirectionState::Prepared); + self.decrypt_state = DirectionState::Active; + } + + /// Set and start using the given `MessageEncrypter` for future outgoing + /// message encryption. + pub(crate) fn set_message_encrypter( + &mut self, + cipher: Box<dyn MessageEncrypter>, + max_messages: u64, + ) { + self.prepare_message_encrypter(cipher, max_messages); + self.start_encrypting(); + } + + /// Set and start using the given `MessageDecrypter` for future incoming + /// message decryption. + pub(crate) fn set_message_decrypter(&mut self, cipher: Box<dyn MessageDecrypter>) { + self.prepare_message_decrypter(cipher); + self.start_decrypting(); + self.trial_decryption_len = None; + } + + /// Set and start using the given `MessageDecrypter` for future incoming + /// message decryption, and enable "trial decryption" mode for when TLS1.3 + /// 0-RTT is attempted but rejected by the server. + pub(crate) fn set_message_decrypter_with_trial_decryption( + &mut self, + cipher: Box<dyn MessageDecrypter>, + max_length: usize, + ) { + self.prepare_message_decrypter(cipher); + self.start_decrypting(); + self.trial_decryption_len = Some(max_length); + } + + pub(crate) fn finish_trial_decryption(&mut self) { + self.trial_decryption_len = None; + } + + pub(crate) fn next_pre_encrypt_action(&self) -> PreEncryptAction { + self.pre_encrypt_action(0) + } + + /// Return a remedial action when we are near to encrypting too many messages. + /// + /// `add` is added to the current sequence number. `add` as `0` means + /// "the next message processed by `encrypt_outgoing`" + pub(crate) fn pre_encrypt_action(&self, add: u64) -> PreEncryptAction { + match self.write_seq.saturating_add(add) { + v if v == self.write_seq_max => PreEncryptAction::RefreshOrClose, + SEQ_HARD_LIMIT.. => PreEncryptAction::Refuse, + _ => PreEncryptAction::Nothing, + } + } + + pub(crate) fn is_encrypting(&self) -> bool { + self.encrypt_state == DirectionState::Active + } + + /// Return true if we have ever decrypted a message. This is used in place + /// of checking the read_seq since that will be reset on key updates. + pub(crate) fn has_decrypted(&self) -> bool { + self.has_decrypted + } + + pub(crate) fn write_seq(&self) -> u64 { + self.write_seq + } + + pub(crate) fn read_seq(&self) -> u64 { + self.read_seq + } + + pub(crate) fn encrypted_len(&self, payload_len: usize) -> usize { + self.message_encrypter + .encrypted_payload_len(payload_len) + } + + fn doing_trial_decryption(&mut self, requested: usize) -> bool { + match self + .trial_decryption_len + .and_then(|value| value.checked_sub(requested)) + { + Some(remaining) => { + self.trial_decryption_len = Some(remaining); + true + } + _ => false, + } + } +} + +/// Result of decryption. +#[derive(Debug)] +pub(crate) struct Decrypted<'a> { + /// Whether the peer appears to be getting close to encrypting too many messages with this key. + pub(crate) want_close_before_decrypt: bool, + /// The decrypted message. + pub(crate) plaintext: InboundPlainMessage<'a>, +} + +#[derive(Debug, Eq, PartialEq)] +pub(crate) enum PreEncryptAction { + /// No action is needed before calling `encrypt_outgoing` + Nothing, + + /// A `key_update` request should be sent ASAP. + /// + /// If that is not possible (for example, the connection is TLS1.2), a `close_notify` + /// alert should be sent instead. + RefreshOrClose, + + /// Do not call `encrypt_outgoing` further, it will panic rather than + /// over-use the key. + Refuse, +} + +const SEQ_SOFT_LIMIT: u64 = 0xffff_ffff_ffff_0000u64; +const SEQ_HARD_LIMIT: u64 = 0xffff_ffff_ffff_fffeu64; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_has_decrypted() { + use crate::{ContentType, ProtocolVersion}; + + struct PassThroughDecrypter; + impl MessageDecrypter for PassThroughDecrypter { + fn decrypt<'a>( + &mut self, + m: InboundOpaqueMessage<'a>, + _: u64, + ) -> Result<InboundPlainMessage<'a>, Error> { + Ok(m.into_plain_message()) + } + } + + // A record layer starts out invalid, having never decrypted. + let mut record_layer = RecordLayer::new(); + assert!(matches!( + record_layer.decrypt_state, + DirectionState::Invalid + )); + assert_eq!(record_layer.read_seq, 0); + assert!(!record_layer.has_decrypted()); + + // Preparing the record layer should update the decrypt state, but shouldn't affect whether it + // has decrypted. + record_layer.prepare_message_decrypter(Box::new(PassThroughDecrypter)); + assert!(matches!( + record_layer.decrypt_state, + DirectionState::Prepared + )); + assert_eq!(record_layer.read_seq, 0); + assert!(!record_layer.has_decrypted()); + + // Starting decryption should update the decrypt state, but not affect whether it has decrypted. + record_layer.start_decrypting(); + assert!(matches!(record_layer.decrypt_state, DirectionState::Active)); + assert_eq!(record_layer.read_seq, 0); + assert!(!record_layer.has_decrypted()); + + // Decrypting a message should update the read_seq and track that we have now performed + // a decryption. + record_layer + .decrypt_incoming(InboundOpaqueMessage::new( + ContentType::Handshake, + ProtocolVersion::TLSv1_2, + &mut [0xC0, 0xFF, 0xEE], + )) + .unwrap(); + assert!(matches!(record_layer.decrypt_state, DirectionState::Active)); + assert_eq!(record_layer.read_seq, 1); + assert!(record_layer.has_decrypted()); + + // Resetting the record layer message decrypter (as if a key update occurred) should reset + // the read_seq number, but not our knowledge of whether we have decrypted previously. + record_layer.set_message_decrypter(Box::new(PassThroughDecrypter)); + assert!(matches!(record_layer.decrypt_state, DirectionState::Active)); + assert_eq!(record_layer.read_seq, 0); + assert!(record_layer.has_decrypted()); + } +} |
