diff options
Diffstat (limited to 'vendor/base64/src/engine/general_purpose/decode.rs')
| -rw-r--r-- | vendor/base64/src/engine/general_purpose/decode.rs | 357 |
1 files changed, 357 insertions, 0 deletions
diff --git a/vendor/base64/src/engine/general_purpose/decode.rs b/vendor/base64/src/engine/general_purpose/decode.rs new file mode 100644 index 00000000..b55d3fc5 --- /dev/null +++ b/vendor/base64/src/engine/general_purpose/decode.rs @@ -0,0 +1,357 @@ +use crate::{ + engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode}, + DecodeError, DecodeSliceError, PAD_BYTE, +}; + +#[doc(hidden)] +pub struct GeneralPurposeEstimate { + /// input len % 4 + rem: usize, + conservative_decoded_len: usize, +} + +impl GeneralPurposeEstimate { + pub(crate) fn new(encoded_len: usize) -> Self { + let rem = encoded_len % 4; + Self { + rem, + conservative_decoded_len: (encoded_len / 4 + (rem > 0) as usize) * 3, + } + } +} + +impl DecodeEstimate for GeneralPurposeEstimate { + fn decoded_len_estimate(&self) -> usize { + self.conservative_decoded_len + } +} + +/// Helper to avoid duplicating num_chunks calculation, which is costly on short inputs. +/// Returns the decode metadata, or an error. +// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is +// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment, +// but this is fragile and the best setting changes with only minor code modifications. +#[inline] +pub(crate) fn decode_helper( + input: &[u8], + estimate: GeneralPurposeEstimate, + output: &mut [u8], + decode_table: &[u8; 256], + decode_allow_trailing_bits: bool, + padding_mode: DecodePaddingMode, +) -> Result<DecodeMetadata, DecodeSliceError> { + let input_complete_nonterminal_quads_len = + complete_quads_len(input, estimate.rem, output.len(), decode_table)?; + + const UNROLLED_INPUT_CHUNK_SIZE: usize = 32; + const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3; + + let input_complete_quads_after_unrolled_chunks_len = + input_complete_nonterminal_quads_len % UNROLLED_INPUT_CHUNK_SIZE; + + let input_unrolled_loop_len = + input_complete_nonterminal_quads_len - input_complete_quads_after_unrolled_chunks_len; + + // chunks of 32 bytes + for (chunk_index, chunk) in input[..input_unrolled_loop_len] + .chunks_exact(UNROLLED_INPUT_CHUNK_SIZE) + .enumerate() + { + let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE; + let chunk_output = &mut output[chunk_index * UNROLLED_OUTPUT_CHUNK_SIZE + ..(chunk_index + 1) * UNROLLED_OUTPUT_CHUNK_SIZE]; + + decode_chunk_8( + &chunk[0..8], + input_index, + decode_table, + &mut chunk_output[0..6], + )?; + decode_chunk_8( + &chunk[8..16], + input_index + 8, + decode_table, + &mut chunk_output[6..12], + )?; + decode_chunk_8( + &chunk[16..24], + input_index + 16, + decode_table, + &mut chunk_output[12..18], + )?; + decode_chunk_8( + &chunk[24..32], + input_index + 24, + decode_table, + &mut chunk_output[18..24], + )?; + } + + // remaining quads, except for the last possibly partial one, as it may have padding + let output_unrolled_loop_len = input_unrolled_loop_len / 4 * 3; + let output_complete_quad_len = input_complete_nonterminal_quads_len / 4 * 3; + { + let output_after_unroll = &mut output[output_unrolled_loop_len..output_complete_quad_len]; + + for (chunk_index, chunk) in input + [input_unrolled_loop_len..input_complete_nonterminal_quads_len] + .chunks_exact(4) + .enumerate() + { + let chunk_output = &mut output_after_unroll[chunk_index * 3..chunk_index * 3 + 3]; + + decode_chunk_4( + chunk, + input_unrolled_loop_len + chunk_index * 4, + decode_table, + chunk_output, + )?; + } + } + + super::decode_suffix::decode_suffix( + input, + input_complete_nonterminal_quads_len, + output, + output_complete_quad_len, + decode_table, + decode_allow_trailing_bits, + padding_mode, + ) +} + +/// Returns the length of complete quads, except for the last one, even if it is complete. +/// +/// Returns an error if the output len is not big enough for decoding those complete quads, or if +/// the input % 4 == 1, and that last byte is an invalid value other than a pad byte. +/// +/// - `input` is the base64 input +/// - `input_len_rem` is input len % 4 +/// - `output_len` is the length of the output slice +pub(crate) fn complete_quads_len( + input: &[u8], + input_len_rem: usize, + output_len: usize, + decode_table: &[u8; 256], +) -> Result<usize, DecodeSliceError> { + debug_assert!(input.len() % 4 == input_len_rem); + + // detect a trailing invalid byte, like a newline, as a user convenience + if input_len_rem == 1 { + let last_byte = input[input.len() - 1]; + // exclude pad bytes; might be part of padding that extends from earlier in the input + if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE { + return Err(DecodeError::InvalidByte(input.len() - 1, last_byte).into()); + } + }; + + // skip last quad, even if it's complete, as it may have padding + let input_complete_nonterminal_quads_len = input + .len() + .saturating_sub(input_len_rem) + // if rem was 0, subtract 4 to avoid padding + .saturating_sub((input_len_rem == 0) as usize * 4); + debug_assert!( + input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len)) + ); + + // check that everything except the last quad handled by decode_suffix will fit + if output_len < input_complete_nonterminal_quads_len / 4 * 3 { + return Err(DecodeSliceError::OutputSliceTooSmall); + }; + Ok(input_complete_nonterminal_quads_len) +} + +/// Decode 8 bytes of input into 6 bytes of output. +/// +/// `input` is the 8 bytes to decode. +/// `index_at_start_of_input` is the offset in the overall input (used for reporting errors +/// accurately) +/// `decode_table` is the lookup table for the particular base64 alphabet. +/// `output` will have its first 6 bytes overwritten +// yes, really inline (worth 30-50% speedup) +#[inline(always)] +fn decode_chunk_8( + input: &[u8], + index_at_start_of_input: usize, + decode_table: &[u8; 256], + output: &mut [u8], +) -> Result<(), DecodeError> { + let morsel = decode_table[usize::from(input[0])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); + } + let mut accum = u64::from(morsel) << 58; + + let morsel = decode_table[usize::from(input[1])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 1, + input[1], + )); + } + accum |= u64::from(morsel) << 52; + + let morsel = decode_table[usize::from(input[2])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 2, + input[2], + )); + } + accum |= u64::from(morsel) << 46; + + let morsel = decode_table[usize::from(input[3])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 3, + input[3], + )); + } + accum |= u64::from(morsel) << 40; + + let morsel = decode_table[usize::from(input[4])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 4, + input[4], + )); + } + accum |= u64::from(morsel) << 34; + + let morsel = decode_table[usize::from(input[5])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 5, + input[5], + )); + } + accum |= u64::from(morsel) << 28; + + let morsel = decode_table[usize::from(input[6])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 6, + input[6], + )); + } + accum |= u64::from(morsel) << 22; + + let morsel = decode_table[usize::from(input[7])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 7, + input[7], + )); + } + accum |= u64::from(morsel) << 16; + + output[..6].copy_from_slice(&accum.to_be_bytes()[..6]); + + Ok(()) +} + +/// Like [decode_chunk_8] but for 4 bytes of input and 3 bytes of output. +#[inline(always)] +fn decode_chunk_4( + input: &[u8], + index_at_start_of_input: usize, + decode_table: &[u8; 256], + output: &mut [u8], +) -> Result<(), DecodeError> { + let morsel = decode_table[usize::from(input[0])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0])); + } + let mut accum = u32::from(morsel) << 26; + + let morsel = decode_table[usize::from(input[1])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 1, + input[1], + )); + } + accum |= u32::from(morsel) << 20; + + let morsel = decode_table[usize::from(input[2])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 2, + input[2], + )); + } + accum |= u32::from(morsel) << 14; + + let morsel = decode_table[usize::from(input[3])]; + if morsel == INVALID_VALUE { + return Err(DecodeError::InvalidByte( + index_at_start_of_input + 3, + input[3], + )); + } + accum |= u32::from(morsel) << 8; + + output[..3].copy_from_slice(&accum.to_be_bytes()[..3]); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + use crate::engine::general_purpose::STANDARD; + + #[test] + fn decode_chunk_8_writes_only_6_bytes() { + let input = b"Zm9vYmFy"; // "foobar" + let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7]; + + decode_chunk_8(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + assert_eq!(&vec![b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], &output); + } + + #[test] + fn decode_chunk_4_writes_only_3_bytes() { + let input = b"Zm9v"; // "foobar" + let mut output = [0_u8, 1, 2, 3]; + + decode_chunk_4(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap(); + assert_eq!(&vec![b'f', b'o', b'o', 3], &output); + } + + #[test] + fn estimate_short_lengths() { + for (range, decoded_len_estimate) in [ + (0..=0, 0), + (1..=4, 3), + (5..=8, 6), + (9..=12, 9), + (13..=16, 12), + (17..=20, 15), + ] { + for encoded_len in range { + let estimate = GeneralPurposeEstimate::new(encoded_len); + assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate()); + } + } + } + + #[test] + fn estimate_via_u128_inflation() { + // cover both ends of usize + (0..1000) + .chain(usize::MAX - 1000..=usize::MAX) + .for_each(|encoded_len| { + // inflate to 128 bit type to be able to safely use the easy formulas + let len_128 = encoded_len as u128; + + let estimate = GeneralPurposeEstimate::new(encoded_len); + assert_eq!( + (len_128 + 3) / 4 * 3, + estimate.conservative_decoded_len as u128 + ); + }) + } +} |
