diff --git a/Cargo.toml b/Cargo.toml index f4a8785..320fd06 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] [dev-dependencies] -criterion = "0.3" +criterion = { version = "0.4", features = [ "real_blackbox" ] } rand = "0.8.5" [[bench]] diff --git a/benches/bench.rs b/benches/bench.rs index 480c935..5024621 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -1,6 +1,5 @@ #![feature(generic_const_exprs)] #![feature(new_uninit)] - #![feature(portable_simd)] use std::mem::MaybeUninit; @@ -15,15 +14,21 @@ const ASCII_BYTES_LONG: &[u8; 256] = b"Donald J. Trump!Donald J. Trump!Donald J. const HEX_BYTES_LONG: &[u8; ASCII_BYTES_LONG.len() * 2] = b"446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021"; macro_rules! name { - ($group:ident, $f:literal) => { - std::boxed::Box::leak(format!(name!("{}", $f), $group).into_boxed_str()) - }; ($group:literal, $f:literal) => { concat!("[", $group, "] - ", $f) }; + ($group:expr, $f:literal) => { + std::boxed::Box::leak(format!(name!("{}", $f), $group).into_boxed_str()) + }; + ($group:expr, $f:expr) => { + std::boxed::Box::leak(format!(name!("{}", "{}"), $group, $f).into_boxed_str()) + }; + ($group:literal, $f:expr) => { + std::boxed::Box::leak(format!(name!($group, "{}"), $f).into_boxed_str()) + }; } -//#[track_caller] +#[track_caller] fn test_sized(hex_bytes: &[u8; N * 2], bytes: &[u8; N]) where [(); N * 2]:, @@ -50,7 +55,7 @@ where ); } -//#[track_caller] +#[track_caller] #[inline] fn test(hex_bytes: &[u8], bytes: &[u8]) { assert_eq!(hex_bytes.len(), bytes.len() * 2); @@ -261,62 +266,120 @@ pub fn bench_1_6m(c: &mut Criterion) { benchmark_sized::("1.6m", &hex_bytes, c); } -pub fn bench_micro_hex_byte(c: &mut Criterion) { +pub fn bench_micro_hex_digit(c: &mut Criterion) { use std::simd::Simd; - const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [ - *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", + const HEX_DIGITS_VALID: [u8; DIGIT_BATCH_SIZE] = [ + 0xf, 0xf, 0x0, 0x0, 0x1, 0x1, 0xe, 0xf, 0xf, 0xe, 0x0, 0xf, 0xf, 0x0, 0x3, 0x4, 0xf, 0xf, + 0x0, 0x0, 0x1, 0x1, 0xe, 0xf, 0xf, 0xe, 0x0, 0xf, 0xf, 0x0, 0x3, + 0x4, + // 0xf, 0xf, 0x0, 0x0, 0x1, 0x1, 0xe, 0xf, 0xf, 0xe, 0x0, 0xf, 0xf, 0x0, 0x3, 0x4, - *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", - - *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", - - *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", + // 0xf, 0xf, 0x0, 0x0, 0x1, 0x1, 0xe, 0xf, 0xf, 0xe, 0x0, 0xf, 0xf, 0x0, 0x3, 0x4, ]; - let hex_bytes = Simd::from_array(conv::u8x2_to_u8(black_box(HEX_BYTES_VALID))); + let hex_digits = Simd::from_array(black_box(HEX_DIGITS_VALID)); + + c.bench_function(name!("micro", "hex_digit"), |b| { + b.iter(|| { + for b in black_box(HEX_DIGITS_VALID) { + black_box(hex_digit(b)); + } + }) + }); + + c.bench_function(name!("micro", "hex_digit_simd"), |b| { + b.iter(|| hex_digit_simd::(hex_digits)) + }); +} + +pub fn bench_micro_hex_byte(c: &mut Criterion) { + const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [ + *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", *b"ff", *b"00", *b"11", + *b"ef", *b"fe", *b"0f", *b"f0", + *b"34", + // *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", + + // *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", + ]; + + fn bench_decoder(c: &mut Criterion, name: &str) { + let hex_bytes = conv::u8x2_to_u8(HEX_BYTES_VALID); + + c.bench_function(name!("micro", format!("{name}::decode_packed")), |b| { + b.iter(|| { + for b in black_box(HEX_BYTES_VALID) { + black_box(T::decode_packed(&b)); + } + }) + }); + + c.bench_function(name!("micro", format!("{name}::decode_unpacked")), |b| { + b.iter(|| { + for [hi, lo] in black_box(HEX_BYTES_VALID) { + black_box(T::decode_unpacked(hi, lo)); + } + }) + }); + + c.bench_function(name!("micro", format!("{name}::decode_simd")), |b| { + b.iter(|| T::decode_simd(black_box(hex_bytes))) + }); + } c.bench_function(name!("micro", "hex_byte"), |b| { b.iter(|| { for b in black_box(HEX_BYTES_VALID) { - hex_byte(b[0], b[1]); + black_box(hex_byte(b[0], b[1])); } }) }); - c.bench_function(name!("micro", "hex_byte_niched"), |b| { - b.iter(|| { - for b in black_box(HEX_BYTES_VALID) { - hex_byte_niched(b[0], b[1]); - } - }) - }); + bench_decoder::(c, stringify!(HexByteDecoderA)); + bench_decoder::(c, stringify!(HexByteDecoderB)); +} - c.bench_function(name!("micro", "hex_byte_simd"), |b| { - b.iter(|| hex_byte_simd(hex_bytes)) +pub fn bench_nano_hex_digit(c: &mut Criterion) { + let digit = black_box('5' as u8); + c.bench_function(name!("nano", "hex_digit"), |b| b.iter(|| hex_digit(digit))); + + c.bench_function(name!("nano", "hex_digit +bb"), |b| { + b.iter(|| hex_digit(black_box(digit))) }); } pub fn bench_nano_hex_byte(c: &mut Criterion) { - let digit = black_box(['5' as u8, 'b' as u8]); + const DIGITS: [u8; 2] = ['5' as u8, 'b' as u8]; + let digit = black_box(DIGITS); + c.bench_function(name!("nano", "hex_byte"), |b| { b.iter(|| hex_byte(digit[0], digit[1])) }); - c.bench_function(name!("nano", "hex_byte_niched"), |b| { - b.iter(|| hex_byte_niched(digit[0], digit[1])) - }); + fn bench_decoder(c: &mut Criterion, name: &str) { + let digit = black_box(DIGITS); - c.bench_function(name!("nano", "hex_byte +bb"), |b| { - b.iter(|| hex_byte(black_box(digit[0]), black_box(digit[1]))) - }); + c.bench_function(name!("nano", format!("{name}::decode_packed")), |b| { + b.iter(|| { + black_box(T::decode_packed(&digit)); + }) + }); - c.bench_function(name!("nano", "hex_byte_niched +bb"), |b| { - b.iter(|| hex_byte_niched(black_box(digit[0]), black_box(digit[1]))) - }); + c.bench_function(name!("nano", format!("{name}::decode_unpacked")), |b| { + b.iter(|| { + black_box(T::decode_unpacked( + black_box(DIGITS[0]), + black_box(DIGITS[1]), + )); + }) + }); + } + + bench_decoder::(c, stringify!(HexByteDecoderA)); + bench_decoder::(c, stringify!(HexByteDecoderB)); } criterion_group!(decode_benches, bench_16, bench_256, bench_1_6m); -criterion_group!(micro_benches, bench_micro_hex_byte); -criterion_group!(nano_benches, bench_nano_hex_byte); +criterion_group!(micro_benches, bench_micro_hex_digit, bench_micro_hex_byte); +criterion_group!(nano_benches, bench_nano_hex_digit, bench_nano_hex_byte); criterion_main!(decode_benches, micro_benches, nano_benches); diff --git a/src/lib.rs b/src/lib.rs index 434b5d6..ac1ae51 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ #![feature(array_chunks)] #![feature(const_slice_index)] +#![feature(const_trait_impl)] #![feature(extend_one)] #![feature(generic_const_exprs)] #![feature(int_log)] @@ -9,19 +10,26 @@ #![feature(const_maybe_uninit_array_assume_init)] #![feature(const_maybe_uninit_uninit_array)] #![feature(new_uninit)] - #![feature(portable_simd)] +pub(crate) mod util; + +pub(crate) mod simd; + use std::mem::MaybeUninit; use std::simd::*; +// use the maximum batch size that would be supported by AVX-512 +//pub const SIMD_WIDTH: usize = 512; +pub const SIMD_WIDTH: usize = 256; + /// The batch size used for the "wide" decoded hex bytes (any bit in the upper half indicates an error). -pub const WIDE_BATCH_SIZE: usize = 512 / 16; +pub const WIDE_BATCH_SIZE: usize = SIMD_WIDTH / 16; /// The batch size used for the hex digits. -// use the maximum batch size that would be supported by AVX-512 pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2; +#[inline] const fn alternating_mask(first_bias: bool) -> [bool; N] { let mut mask = [false; N]; let mut i = 0; @@ -39,6 +47,7 @@ const fn alternating_mask(first_bias: bool) -> [bool; N] { mask } +#[inline] const fn msb_lsb_indices() -> [usize; N] { if N % 2 != 0 { panic!("Illegal N"); @@ -55,6 +64,7 @@ const fn msb_lsb_indices() -> [usize; N] { indices } +#[inline] const fn alternating_indices(first_bias: bool) -> [usize; N] { let mut indices = [0; N]; let mut i = 0; @@ -75,6 +85,10 @@ const fn alternating_indices(first_bias: bool) -> [usize; N] { const MSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(true); const LSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(false); +pub const INVALID_BIT: u8 = 0b1000_0000; + +pub const WIDE_INVALID_BIT: u16 = 0b1000_1000_0000_0000; + const ASCII_DIGITS: [u8; 256] = { let mut digits = [0u8; 256]; let mut i = u8::MIN; @@ -90,7 +104,7 @@ const ASCII_DIGITS: [u8; 256] = { DIGIT_MIN..=DIGIT_MAX => i - DIGIT_MIN, LOWER_MIN..=LOWER_MAX => 10 + i - LOWER_MIN, UPPER_MIN..=UPPER_MAX => 10 + i - UPPER_MIN, - _ => 255, + _ => INVALID_BIT, }; i += 1; @@ -98,7 +112,7 @@ const ASCII_DIGITS: [u8; 256] = { digits }; -/// Returns 255 if invalid. Based on `char.to_digit()` in the stdlib. +/// Returns [`INVALID_BIT`] if invalid. Based on `char.to_digit()` in the stdlib. #[inline] pub const fn hex_digit(ascii: u8) -> u8 { // use std::ops::RangeInclusive; @@ -112,7 +126,7 @@ pub const fn hex_digit(ascii: u8) -> u8 { // DIGIT_MIN..=DIGIT_MAX => ascii - DIGIT_MIN, // LOWER_MIN..=LOWER_MAX => 10 + ascii - LOWER_MIN, // UPPER_MIN..=UPPER_MAX => 10 + ascii - UPPER_MIN, - // _ => 255, + // _ => INVALID_BIT, // } ASCII_DIGITS[ascii as usize] // let mut digit = ascii.wrapping_sub('0' as u8); @@ -124,69 +138,162 @@ pub const fn hex_digit(ascii: u8) -> u8 { // if digit < 6 { // return digit + 10; // } - // return 255; + // return INVALID_BIT; } #[inline(always)] -pub fn hex_digits(ascii: Simd) -> Simd where LaneCount: SupportedLaneCount { - unsafe { Simd::gather_select_unchecked(&ASCII_DIGITS, Mask::splat(true), ascii.cast(), Simd::splat(0)) } +pub fn hex_digit_simd(ascii: Simd) -> Simd +where + LaneCount: SupportedLaneCount, +{ + unsafe { + Simd::gather_select_unchecked( + &ASCII_DIGITS, + Mask::splat(true), + ascii.cast(), + simd::splat_0::(), + ) + } } /// Parses an ascii hex byte. -#[inline] +#[inline(always)] pub const fn hex_byte(msb: u8, lsb: u8) -> Option { let msb = hex_digit(msb); let lsb = hex_digit(lsb); - if (msb | lsb) == 255 { + // second is faster (perhaps it pipelines better?) + //if (msb | lsb) & INVALID_BIT != 0 { + if (msb & INVALID_BIT) | (lsb & INVALID_BIT) != 0 { return None; } Some(msb << 4 | lsb) } -/// Parses an ascii hex byte. Any value > [`u8::MAX`] is invalid. -#[inline] -pub const fn hex_byte_niched(msb: u8, lsb: u8) -> u16 { - let msb = hex_digit(msb) as u16; - let lsb = hex_digit(lsb) as u16; - (msb << 4) | (lsb & 0xf) | ((lsb & 0xf0) << 8) -} +/// A decoder for a single hex byte. +#[const_trait] +pub trait HexByteDecoder { + /// Parses an ascii hex byte. Any return value exceeding [`u8::MAX`] indicates invalid input. + fn decode_unpacked(hi: u8, lo: u8) -> u16; -/// Parses an ascii hex byte. -#[inline] -pub const fn hex_byte_packed([msb, lsb]: &[u8; 2]) -> Option { - let msb = hex_digit(*msb); - let lsb = hex_digit(*lsb); - if (msb | lsb) == 255 { - return None; + /// Parses an ascii hex byte. Any return value exceeding [`u8::MAX`] indicates invalid input. + #[inline(always)] + fn decode_packed([hi, lo]: &[u8; 2]) -> u16 { + Self::decode_unpacked(*hi, *lo) } - Some(msb << 4 | lsb) } -/// Parses an ascii hex byte. Any value > [`u8::MAX`] is invalid. -#[inline] -pub const fn hex_byte_packed_niched([msb, lsb]: &[u8; 2]) -> u16 { - let msb = hex_digit(*msb) as u16; - let lsb = hex_digit(*lsb) as u16; - (msb << 4) | (lsb & 0xf) | ((lsb & 0xf0) << 8) +/// A decoder for a sized batch of hex bytes. +pub trait HexByteSimdDecoder { + /// Parses an ascii hex byte. Any element of the return value exceeding [`u8::MAX`] indicates invalid input. + fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option>; } -#[inline] -pub fn hex_byte_simd(hex_bytes: Simd) -> Simd { - let hex_digits = hex_digits(hex_bytes); - //println!("hex_digits: {hex_digits:04x?}"); - //println!("MSB_INDICES: {MSB_INDICES:02?}"); - //println!("LSB_INDICES: {LSB_INDICES:02?}"); - let msb = simd_swizzle!(hex_digits, MSB_INDICES); - let lsb = simd_swizzle!(hex_digits, LSB_INDICES); - //println!("msb: {msb:04x?}"); - //println!("lsb: {lsb:04x?}"); - let msb = msb.cast::(); - let lsb = lsb.cast::(); - msb << Simd::splat(4) | lsb | ((lsb & Simd::splat(0xf0)) << Simd::splat(8)) +pub struct HexByteDecoderA; + +impl const HexByteDecoder for HexByteDecoderA { + #[inline(always)] + fn decode_unpacked(hi: u8, lo: u8) -> u16 { + let hi = hex_digit(hi) as u16; + let lo = hex_digit(lo) as u16; + // might these these masks allow the ORs the be pipelined more efficiently? + (hi << 4) | (lo & 0xf) | ((lo & 0xf0) << 8) + } + + #[inline(always)] + fn decode_packed([hi, lo]: &[u8; 2]) -> u16 { + let hi = hex_digit(*hi) as u16; + let lo = hex_digit(*lo) as u16; + (hi << 4) | (lo & 0xf) | ((lo & 0xf0) << 8) + } } +impl HexByteSimdDecoder for HexByteDecoderA { + #[inline(always)] + fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option> { + let hex_digits = hex_digit_simd::(Simd::from_array(hi_los)); + if ((hex_digits & simd::splat_n::(INVALID_BIT)) + .simd_ne(simd::splat_0::())) + .any() + { + return None; + } + let msb = simd_swizzle!(hex_digits, MSB_INDICES); + let lsb = simd_swizzle!(hex_digits, LSB_INDICES); + /*let msb = msb.cast::(); + let lsb = lsb.cast::(); + let buf = msb << simd::splat_n::(4) | lsb | ((lsb & simd::splat_n::(0xf0)) << simd::splat_n::(8)); + if buf.simd_gt(simd::splat_n::(u8::MAX as u16)).any() { + return None; + } + Some(buf.cast::())*/ + Some((msb << simd::splat_n::(4)) | lsb) + } +} + +pub struct HexByteDecoderB; + +impl const HexByteDecoder for HexByteDecoderB { + util::defer_impl! { + => HexByteDecoderA; + + //fn decode_unpacked(hi: u8, lo: u8) -> u16; + + //fn decode_packed(hi_lo: &[u8; 2]) -> u16; + } + + #[inline(always)] + fn decode_unpacked(hi: u8, lo: u8) -> u16 { + let lo = hex_digit(lo) as u16; + let hi = hex_digit(hi) as u16; + // kind of bizarre: changing the order of these decreases perf by 6-12% + (hi << 4) | lo | ((lo & INVALID_BIT as u16) << 8) + } + + #[inline(always)] + fn decode_packed([hi, lo]: &[u8; 2]) -> u16 { + let lo = hex_digit(*lo) as u16; + let hi = hex_digit(*hi) as u16; + (hi << 4) | lo | ((lo & INVALID_BIT as u16) << 8) + } +} + +impl HexByteSimdDecoder for HexByteDecoderB { + util::defer_impl! { + => HexByteDecoderA; + + //fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option>; + } + + #[inline(always)] + fn decode_simd(mut hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option> { + for b in hi_los.iter_mut() { + *b = hex_digit(*b); + } + let hex_digits = Simd::from_array(hi_los); + if (hex_digits & simd::splat_n::(INVALID_BIT)) + .simd_ne(simd::splat_0::()) + .any() + { + //if hex_digits.simd_eq(simd::splat_n::(INVALID_BIT)).any() { + return None; + } + let msb = simd_swizzle!(hex_digits, MSB_INDICES); + let lsb = simd_swizzle!(hex_digits, LSB_INDICES); + let mut v = Simd::from_array([0u8; WIDE_BATCH_SIZE]); + for (i, v) in v.as_mut_array().iter_mut().enumerate() { + let hi = unsafe { *msb.as_array().get_unchecked(i) }; + let lo = unsafe { *lsb.as_array().get_unchecked(i) }; + *v = (hi << 4) | lo; + } + Some(v) + //msb << simd::splat_n::(4) | lsb | ((lsb & simd::splat_n::(0xf0)) << simd::splat_n::(8)) + } +} + +pub type HBD = HexByteDecoderB; + pub mod conv { - use std::simd::{Simd, LaneCount, SupportedLaneCount}; + use std::simd::{LaneCount, Simd, SupportedLaneCount}; /*trait Size { const N: usize; @@ -241,19 +348,29 @@ pub mod conv { size_of_impl!(u64 = SizeU64);*/ #[allow(non_camel_case_types, non_snake_case)] - union u8_u16 where [u8; N_u16 * 2]: { + union u8_u16 + where + [u8; N_u16 * 2]:, + { u8: [u8; N_u16 * 2], u16: [u16; N_u16], } #[allow(non_camel_case_types, non_snake_case)] - union u8x2_u8 where [u8; N_u8x2 * 2]: { + union u8x2_u8 + where + [u8; N_u8x2 * 2]:, + { u8x2: [[u8; 2]; N_u8x2], u8: [u8; N_u8x2 * 2], } #[allow(non_camel_case_types, non_snake_case)] - union SimdU8_SimdU16 where LaneCount<{ N_U16 * 2 }>: SupportedLaneCount, LaneCount: SupportedLaneCount { + union SimdU8_SimdU16 + where + LaneCount<{ N_U16 * 2 }>: SupportedLaneCount, + LaneCount: SupportedLaneCount, + { SimdU8: Simd, SimdU16: Simd, } @@ -269,10 +386,12 @@ pub mod conv { } #[inline(always)] - pub const fn simdu8_to_simdu16(a: Simd) -> Simd + pub const fn simdu8_to_simdu16( + a: Simd, + ) -> Simd where LaneCount<{ N_OUT * 2 }>: SupportedLaneCount, - LaneCount: SupportedLaneCount + LaneCount: SupportedLaneCount, { unsafe { SimdU8_SimdU16 { SimdU8: a }.SimdU16 } } @@ -296,16 +415,15 @@ const fn align_up_to(n: usize) -> usize { return (n + (N - 1)) >> shift << shift; } - macro_rules! decode_hex_bytes_non_vectored { ($i:ident, $ascii:ident, $bytes:ident, $o:expr) => {{ while $i < $ascii.len() { match unsafe { hex_byte(*$ascii.get_unchecked($i), *$ascii.get_unchecked($i + 1)) } { Some(b) => unsafe { *$bytes.get_unchecked_mut($o + ($i >> 1)) = MaybeUninit::new(b) }, None => { - println!("bad hex byte at {} ({}{})", $i, $ascii[$i] as char, $ascii[$i + 1] as char); + //println!("bad hex byte at {} ({}{})", $i, $ascii[$i] as char, $ascii[$i + 1] as char); return false - }, + } } $i += 2; } @@ -314,29 +432,34 @@ macro_rules! decode_hex_bytes_non_vectored { #[inline(always)] fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit]) -> bool { - debug_assert_eq!(ascii.len() >> 1 << 1, ascii.len(), "len of ascii is not a multiple of 2"); - debug_assert_eq!(ascii.len() >> 1, bytes.len(), "len of ascii is not twice that of bytes"); - const VECTORED: bool = true; - if VECTORED { - union Aligned { - bytes: [u8; DIGIT_BATCH_SIZE], - simd: Simd, - } - + debug_assert_eq!( + ascii.len() >> 1 << 1, + ascii.len(), + "len of ascii is not a multiple of 2" + ); + debug_assert_eq!( + ascii.len() >> 1, + bytes.len(), + "len of ascii is not twice that of bytes" + ); + const VECTORED_A: bool = false; + const VECTORED_B: bool = false; + const VECTORED_C: bool = false; + if VECTORED_A { let mut i = 0; while i < align_down_to::(ascii.len()) { - let slice = unsafe { &*(ascii.as_ptr().add(i) as *const [u8; DIGIT_BATCH_SIZE]) }; - let aligned = Aligned { bytes: *slice }; - let buf = hex_byte_simd(unsafe { aligned.simd }); - if buf > Simd::splat(u8::MAX as u16) { - println!("bad hex byte at {}..{}", i, i + DIGIT_BATCH_SIZE); - return false; - } - let buf = buf.cast::(); + let buf = HBD::decode_simd(unsafe { + *(ascii.as_ptr().add(i) as *const [u8; DIGIT_BATCH_SIZE]) + }); + let buf = match buf { + Some(buf) => buf, + None => return false, + }; let mut j = 0; while j < DIGIT_BATCH_SIZE { unsafe { - *bytes.get_unchecked_mut((i >> 1) + j) = MaybeUninit::new(*buf.as_array().get_unchecked(j)) + *bytes.get_unchecked_mut((i >> 1) + j) = + MaybeUninit::new(*buf.as_array().get_unchecked(j)) }; j += 1; } @@ -344,28 +467,30 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit]) -> bo } decode_hex_bytes_non_vectored!(i, ascii, bytes, 0); - } else if false { - let (ascii_pre, ascii_simd, ascii_post) = unsafe { ascii.align_to::>() }; + } else if VECTORED_B { + let (ascii_pre, ascii_simd, ascii_post) = + unsafe { ascii.align_to::>() }; - assert_eq!(ascii_pre.len() % 2, 0); - assert_eq!(ascii_post.len() % 2, 0); + debug_assert_eq!(ascii_pre.len() % 2, 0); + debug_assert_eq!(ascii_post.len() % 2, 0); let mut i = 0; decode_hex_bytes_non_vectored!(i, ascii_pre, bytes, 0); let mut i = 0; while i < ascii_simd.len() { - let buf = hex_byte_simd(unsafe { *ascii_simd.get_unchecked(i) }); - if buf > Simd::splat(u8::MAX as u16) { - println!("bad hex byte at {}..{}", i, i + DIGIT_BATCH_SIZE); - return false; - } - let buf = buf.cast::(); + // this to_array and any subsequent from_array should be eliminated anyway + let buf = HBD::decode_simd(unsafe { ascii_simd.get_unchecked(i) }.to_array()); + let buf = match buf { + Some(buf) => buf, + None => return false, + }; let mut j = 0; let k = ascii_pre.len() + i * DIGIT_BATCH_SIZE; while j < DIGIT_BATCH_SIZE { unsafe { - *bytes.get_unchecked_mut(k + j) = MaybeUninit::new(*buf.as_array().get_unchecked(j)) + *bytes.get_unchecked_mut(k + j) = + MaybeUninit::new(*buf.as_array().get_unchecked(j)) }; j += 1; } @@ -375,6 +500,28 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit]) -> bo let mut i = 0; let k = ascii.len() - ascii_post.len(); decode_hex_bytes_non_vectored!(i, ascii_post, bytes, k); + } else if VECTORED_C { + let mut i = 0; + while i < align_down_to::(ascii.len()) { + let buf = HBD::decode_simd(unsafe { + *(ascii.as_ptr().add(i) as *const [u8; DIGIT_BATCH_SIZE]) + }); + let buf = match buf { + Some(buf) => buf, + None => return false, + }; + let mut j = 0; + while j < DIGIT_BATCH_SIZE { + unsafe { + *bytes.get_unchecked_mut((i >> 1) + j) = + MaybeUninit::new(*buf.as_array().get_unchecked(j)) + }; + j += 1; + } + i += DIGIT_BATCH_SIZE; + } + + decode_hex_bytes_non_vectored!(i, ascii, bytes, 0); } else { let mut i = 0; decode_hex_bytes_non_vectored!(i, ascii, bytes, 0); @@ -391,14 +538,17 @@ pub const fn hex_bytes_sized_const(ascii: &[u8; N * 2]) -> Optio } else { let mut bytes = MaybeUninit::::uninit_array::(); let mut i = 0; - while i < bytes.len() { - match hex_byte(unsafe { *ascii.get_unchecked(i << 1) }, unsafe { - *ascii.get_unchecked((i << 1) + 1) + while i < N * 2 { + if i >> 1 >= bytes.len() { + unsafe { std::hint::unreachable_unchecked() }; + } + match hex_byte(unsafe { *ascii.get_unchecked(i) }, unsafe { + *ascii.get_unchecked(i + 1) }) { - Some(b) => bytes[i] = MaybeUninit::new(b), + Some(b) => bytes[i >> 1] = MaybeUninit::new(b), None => return None, } - i += 1; + i += 2; } Some(unsafe { MaybeUninit::array_assume_init(bytes) }) } @@ -418,12 +568,12 @@ pub fn hex_bytes_sized(ascii: &[u8; N * 2]) -> Option<[u8; N]> { } } +#[inline] pub fn hex_bytes_sized_heap(ascii: &[u8; N * 2]) -> Option> { if N == 0 { Some(Box::new([0u8; N])) } else { - let mut bytes: Box<[MaybeUninit; N]> = - unsafe { Box::<[MaybeUninit; N]>::new_uninit().assume_init() }; + let mut bytes = unsafe { Box::<[MaybeUninit; N]>::new_uninit().assume_init() }; if decode_hex_bytes_unchecked(ascii, bytes.as_mut()) { Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8; N]) }) } else { @@ -432,19 +582,21 @@ pub fn hex_bytes_sized_heap(ascii: &[u8; N * 2]) -> Option Option> { let len = ascii.len() >> 1; if len << 1 != ascii.len() { return None; } - let mut bytes = Box::<[u8]>::new_uninit_slice(len); + let mut bytes = Box::new_uninit_slice(len); if decode_hex_bytes_unchecked(ascii, bytes.as_mut()) { - Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8]) }) + Some(unsafe { Box::<[_]>::assume_init(bytes) }) } else { None } } +#[inline] pub fn hex_bytes_dyn_unsafe_iter(ascii: &[u8]) -> Option> { let len = ascii.len() >> 1; if len << 1 != ascii.len() { @@ -462,9 +614,10 @@ pub fn hex_bytes_dyn_unsafe_iter(ascii: &[u8]) -> Option> { return None; } } - Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8]) }) + Some(unsafe { Box::<[_]>::assume_init(bytes) }) } +#[inline] pub fn hex_bytes_dyn_unsafe_iter_niched(ascii: &[u8]) -> Option> { let len = ascii.len() >> 1; if len << 1 != ascii.len() { @@ -473,18 +626,18 @@ pub fn hex_bytes_dyn_unsafe_iter_niched(ascii: &[u8]) -> Option> { let mut bytes = Box::<[u8]>::new_uninit_slice(len); for (i, b) in ascii .array_chunks::<2>() - .map(|[msb, lsb]| hex_byte_niched(*msb, *lsb)) + .map(HBD::decode_packed) .enumerate() { - if b & 0xff_00 == 0 { - unsafe { *bytes.get_unchecked_mut(i) = MaybeUninit::new(b as u8) }; - } else { + if b & WIDE_INVALID_BIT != 0 { return None; } + unsafe { *bytes.get_unchecked_mut(i) = MaybeUninit::new(b as u8) }; } - Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8]) }) + Some(unsafe { Box::<[_]>::assume_init(bytes) }) } +#[inline] pub fn hex_bytes_dyn(ascii: &[u8]) -> Option> { let iter = ascii.array_chunks::<2>(); if iter.remainder().len() != 0 { @@ -495,13 +648,20 @@ pub fn hex_bytes_dyn(ascii: &[u8]) -> Option> { .map(|v| v.into_boxed_slice()) } +#[inline] pub fn hex_bytes_dyn_niched(ascii: &[u8]) -> Option> { let iter = ascii.array_chunks::<2>(); if iter.remainder().len() != 0 { return None; } - iter.map(|[msb, lsb]| hex_byte_niched(*msb, *lsb)) - .map(|b| if b & 0xff_00 == 0 { Some(b as u8) } else { None }) + iter.map(HBD::decode_packed) + .map(|b| { + if b & WIDE_INVALID_BIT != 0 { + None + } else { + Some(b as u8) + } + }) .collect::>>() .map(|v| v.into_boxed_slice()) } @@ -522,8 +682,21 @@ mod test { } const SAMPLES: &[Sample] = &[ - Sample { bytes: BYTES, hex_bytes: HEX_BYTES }, - Sample { bytes: LONG_BYTES, hex_bytes: LONG_HEX_BYTES }, + Sample { + bytes: BYTES, + hex_bytes: HEX_BYTES, + }, + Sample { + bytes: LONG_BYTES, + hex_bytes: LONG_HEX_BYTES, + }, + ]; + + const INVALID_SAMPLES: &[&[u8]] = &[ + b"446F6C6F72756D2064697374696E6374696F20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722G", + b"446F6C6F72756D2064697374696E6374696F20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E7365637465747572GE", + b"446F6C6F72756D2064697374696E6374696G20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E", + b"446F6C6F72756D2064697374696E637469GF20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E", ]; #[test] @@ -547,8 +720,10 @@ mod test { #[test] fn test_hex_digit_simd() { const HEX_DIGITS: &[char; DIGIT_BATCH_SIZE] = &[ - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', - '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0', + '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', + 'F', + // '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' ]; let mut set8 = [0u8; DIGIT_BATCH_SIZE]; @@ -559,7 +734,10 @@ mod test { for i in 0..(DIGIT_BATCH_SIZE) { sete[i] = (i as u8) % 16; } - assert_eq!(hex_digits(Simd::from_array(set8)), Simd::from_array(sete)); + assert_eq!( + hex_digit_simd::(Simd::from_array(set8)), + Simd::from_array(sete) + ); } #[test] @@ -576,7 +754,12 @@ mod test { for (hb, b) in HEX_BYTES_VALID { assert_eq!(hex_byte(hb[0], hb[1]), Some(*b)); - assert_eq!(hex_byte_niched(hb[0], hb[1]), *b as u16); + + assert_eq!(HexByteDecoderA::decode_unpacked(hb[0], hb[1]), *b as u16); + assert_eq!(HexByteDecoderB::decode_unpacked(hb[0], hb[1]), *b as u16); + + assert_eq!(HexByteDecoderA::decode_packed(hb), *b as u16); + assert_eq!(HexByteDecoderB::decode_packed(hb), *b as u16); } const HEX_BYTES_INVALID: &[[u8; 2]] = &[ @@ -591,39 +774,47 @@ mod test { for hb in HEX_BYTES_INVALID { assert_eq!(hex_byte(hb[0], hb[1]), None); - assert!(hex_byte_niched(hb[0], hb[1]) & 0xff_00 != 0); + assert_ne!( + HexByteDecoderA::decode_unpacked(hb[0], hb[1]) & WIDE_INVALID_BIT, + 0 + ); + assert_ne!( + HexByteDecoderB::decode_unpacked(hb[0], hb[1]) & WIDE_INVALID_BIT, + 0 + ); + + assert_ne!(HexByteDecoderA::decode_packed(hb) & WIDE_INVALID_BIT, 0); + assert_ne!(HexByteDecoderB::decode_packed(hb) & WIDE_INVALID_BIT, 0); } } #[test] fn test_hex_byte_simd() { const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [ - *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", + *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", *b"ff", *b"00", *b"11", + *b"ef", *b"fe", *b"0f", *b"f0", + *b"34", + // *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", - *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", - - *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", - - *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", + // *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", ]; const BYTES_VALID: [u8; WIDE_BATCH_SIZE] = [ - 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, + 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, + 0xf0, + 0x34, + // 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, - 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, - - 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, - - 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, + // 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, ]; - let hex_bytes = Simd::from_array(conv::u8x2_to_u8(HEX_BYTES_VALID)); - let bytes = BYTES_VALID.map(|b| b as u16); - let bytes = Simd::from_array(bytes); + let hex_bytes = conv::u8x2_to_u8(HEX_BYTES_VALID); + let bytes = Simd::from_array(BYTES_VALID); println!("hex_bytes: {HEX_BYTES_VALID:02x?}"); println!("hex_bytes: {hex_bytes:02x?}"); println!("bytes: {BYTES_VALID:02x?}"); println!("bytes: {bytes:04x?}"); - assert_eq!(hex_byte_simd(hex_bytes), bytes); + assert_eq!(HexByteDecoderA::decode_simd(hex_bytes), Some(bytes)); + assert_eq!(HexByteDecoderB::decode_simd(hex_bytes), Some(bytes)); /*const HEX_BYTES_INVALID: &[[u8; 2]] = &[ ['f' as u8, 'g' as u8], @@ -637,14 +828,27 @@ mod test { for hb in HEX_BYTES_INVALID { assert_eq!(hex_byte(hb[0], hb[1]), None); - assert!(hex_byte_niched(hb[0], hb[1]) & 0xff_00 != 0); + assert!(hex_byte_niched(hb[0], hb[1]) & WIDE_INVALID_BIT != 0); }*/ } fn test_f(f: fn(&[u8]) -> Option>) { for (i, Sample { bytes, hex_bytes }) in SAMPLES.into_iter().enumerate() { let result = f(hex_bytes); - assert_eq!(Some(*bytes), result.as_ref().map(Box::as_ref), "Sample {i} did not decode correctly"); + assert_eq!( + Some(*bytes), + result.as_ref().map(Box::as_ref), + "Sample {i} did not decode correctly" + ); + } + + for (i, hex_bytes) in INVALID_SAMPLES.into_iter().enumerate() { + let result = f(hex_bytes); + assert_eq!( + None, + result.as_ref().map(Box::as_ref), + "Sample {i} did not decode correctly" + ); } } diff --git a/src/simd.rs b/src/simd.rs new file mode 100644 index 0000000..a964be5 --- /dev/null +++ b/src/simd.rs @@ -0,0 +1,89 @@ +use std::simd::{LaneCount, Simd, SupportedLaneCount}; + +use crate::util::cast; + +pub trait SimdSplatZero { + fn splat_zero() -> Simd + where + LaneCount: SupportedLaneCount; +} + +pub trait SimdSplatN { + fn splat_n(n: u8) -> Simd + where + LaneCount: SupportedLaneCount; +} + +pub struct SimdOps; + +const W_128: usize = 128 / 8; +const W_256: usize = 256 / 8; +const W_512: usize = 512 / 8; + +macro_rules! specialized { + ($LANES:ident, $trait:ident { + $( + fn $name:ident($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? { + $( + $width:pat_param $( if $cfg:meta )? => $impl:expr + ),+ + $(,)? + } + )* + }) => { + impl $trait<$LANES> for SimdOps { + $( + #[inline(always)] + fn $name($( $argn: $argt ),*) -> $rt $( where $( $where )* )? { + // abusing const generics to specialize without the unsoundness of real specialization! + match $LANES { + $( + $( #[cfg( $cfg )] )? + $width => $impl + ),+ + } + } + )* + } + }; +} + +specialized!(LANES, SimdSplatN { + fn splat_n(n: u8) -> Simd where [LaneCount: SupportedLaneCount] { + W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) }, + W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_set1_epi8(n as i8)) }, + W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_set1_epi8(n as i8)) }, + W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_set1_epi8(n as i8)) }, + W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi8(n as i8)) }, + W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi8(n as i8)) }, + _ => Simd::splat(n), + } +}); + +specialized!(LANES, SimdSplatZero { + fn splat_zero() -> Simd where [LaneCount: SupportedLaneCount] { + W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) }, + W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) }, + W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) }, + W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) }, + W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) }, + W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) }, + _ => >::splat_n(0), + } +}); + +#[inline(always)] +pub fn splat_0() -> Simd +where + LaneCount: SupportedLaneCount, +{ + >::splat_zero() +} + +#[inline(always)] +pub fn splat_n(n: u8) -> Simd +where + LaneCount: SupportedLaneCount, +{ + >::splat_n(n) +} diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..8bf5fd8 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,55 @@ +#[doc(hidden)] +#[macro_export] +macro_rules! __defer_impl { + ( + => $impl:ident; + + $( fn $name:ident($( $pname:ident: $pty:ty ),*) -> $rty:ty; )* + ) => { + $( + #[inline(always)] + fn $name($( $pname: $pty ),*) -> $rty { + <$impl>::$name($( $pname ),*) + } + )* + }; +} + +pub use __defer_impl as defer_impl; + +#[inline(always)] +#[cold] +pub fn cold() {} + +#[inline] +pub fn likely(b: bool) -> bool { + if !b { + cold() + } + b +} + +#[inline] +pub fn unlikely(b: bool) -> bool { + if b { + cold() + } + b +} + +/// Like transmute, but implemented via a union so that we can use it in situations where +/// transmute's "safety" restrictions are too strict and uninformed (i.e. we can prove it is safe). +#[inline(always)] +pub unsafe fn cast(a: A) -> B { + union Cast { + a: std::mem::ManuallyDrop, + b: std::mem::ManuallyDrop, + } + + std::mem::ManuallyDrop::into_inner( + Cast { + a: std::mem::ManuallyDrop::new(a), + } + .b, + ) +}