This commit is contained in:
Michael Pfaff 2022-10-30 15:35:17 -04:00
parent 0942be2634
commit 379fa71d78
Signed by: michael
GPG Key ID: CF402C4A012AA9D4
5 changed files with 574 additions and 163 deletions

View File

@ -6,7 +6,7 @@ edition = "2021"
[dependencies] [dependencies]
[dev-dependencies] [dev-dependencies]
criterion = "0.3" criterion = { version = "0.4", features = [ "real_blackbox" ] }
rand = "0.8.5" rand = "0.8.5"
[[bench]] [[bench]]

View File

@ -1,6 +1,5 @@
#![feature(generic_const_exprs)] #![feature(generic_const_exprs)]
#![feature(new_uninit)] #![feature(new_uninit)]
#![feature(portable_simd)] #![feature(portable_simd)]
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
@ -15,15 +14,21 @@ const ASCII_BYTES_LONG: &[u8; 256] = b"Donald J. Trump!Donald J. Trump!Donald J.
const HEX_BYTES_LONG: &[u8; ASCII_BYTES_LONG.len() * 2] = b"446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021"; const HEX_BYTES_LONG: &[u8; ASCII_BYTES_LONG.len() * 2] = b"446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021";
macro_rules! name { macro_rules! name {
($group:ident, $f:literal) => {
std::boxed::Box::leak(format!(name!("{}", $f), $group).into_boxed_str())
};
($group:literal, $f:literal) => { ($group:literal, $f:literal) => {
concat!("[", $group, "] - ", $f) concat!("[", $group, "] - ", $f)
}; };
($group:expr, $f:literal) => {
std::boxed::Box::leak(format!(name!("{}", $f), $group).into_boxed_str())
};
($group:expr, $f:expr) => {
std::boxed::Box::leak(format!(name!("{}", "{}"), $group, $f).into_boxed_str())
};
($group:literal, $f:expr) => {
std::boxed::Box::leak(format!(name!($group, "{}"), $f).into_boxed_str())
};
} }
//#[track_caller] #[track_caller]
fn test_sized<const N: usize, const HEAP_ONLY: bool>(hex_bytes: &[u8; N * 2], bytes: &[u8; N]) fn test_sized<const N: usize, const HEAP_ONLY: bool>(hex_bytes: &[u8; N * 2], bytes: &[u8; N])
where where
[(); N * 2]:, [(); N * 2]:,
@ -50,7 +55,7 @@ where
); );
} }
//#[track_caller] #[track_caller]
#[inline] #[inline]
fn test(hex_bytes: &[u8], bytes: &[u8]) { fn test(hex_bytes: &[u8], bytes: &[u8]) {
assert_eq!(hex_bytes.len(), bytes.len() * 2); assert_eq!(hex_bytes.len(), bytes.len() * 2);
@ -261,62 +266,120 @@ pub fn bench_1_6m(c: &mut Criterion) {
benchmark_sized::<LEN, true>("1.6m", &hex_bytes, c); benchmark_sized::<LEN, true>("1.6m", &hex_bytes, c);
} }
pub fn bench_micro_hex_byte(c: &mut Criterion) { pub fn bench_micro_hex_digit(c: &mut Criterion) {
use std::simd::Simd; use std::simd::Simd;
const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [ const HEX_DIGITS_VALID: [u8; DIGIT_BATCH_SIZE] = [
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", 0xf, 0xf, 0x0, 0x0, 0x1, 0x1, 0xe, 0xf, 0xf, 0xe, 0x0, 0xf, 0xf, 0x0, 0x3, 0x4, 0xf, 0xf,
0x0, 0x0, 0x1, 0x1, 0xe, 0xf, 0xf, 0xe, 0x0, 0xf, 0xf, 0x0, 0x3,
0x4,
// 0xf, 0xf, 0x0, 0x0, 0x1, 0x1, 0xe, 0xf, 0xf, 0xe, 0x0, 0xf, 0xf, 0x0, 0x3, 0x4,
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", // 0xf, 0xf, 0x0, 0x0, 0x1, 0x1, 0xe, 0xf, 0xf, 0xe, 0x0, 0xf, 0xf, 0x0, 0x3, 0x4,
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
]; ];
let hex_bytes = Simd::from_array(conv::u8x2_to_u8(black_box(HEX_BYTES_VALID))); let hex_digits = Simd::from_array(black_box(HEX_DIGITS_VALID));
c.bench_function(name!("micro", "hex_digit"), |b| {
b.iter(|| {
for b in black_box(HEX_DIGITS_VALID) {
black_box(hex_digit(b));
}
})
});
c.bench_function(name!("micro", "hex_digit_simd"), |b| {
b.iter(|| hex_digit_simd::<DIGIT_BATCH_SIZE>(hex_digits))
});
}
pub fn bench_micro_hex_byte(c: &mut Criterion) {
const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", *b"ff", *b"00", *b"11",
*b"ef", *b"fe", *b"0f", *b"f0",
*b"34",
// *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
// *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
];
fn bench_decoder<T: HexByteDecoder + HexByteSimdDecoder>(c: &mut Criterion, name: &str) {
let hex_bytes = conv::u8x2_to_u8(HEX_BYTES_VALID);
c.bench_function(name!("micro", format!("{name}::decode_packed")), |b| {
b.iter(|| {
for b in black_box(HEX_BYTES_VALID) {
black_box(T::decode_packed(&b));
}
})
});
c.bench_function(name!("micro", format!("{name}::decode_unpacked")), |b| {
b.iter(|| {
for [hi, lo] in black_box(HEX_BYTES_VALID) {
black_box(T::decode_unpacked(hi, lo));
}
})
});
c.bench_function(name!("micro", format!("{name}::decode_simd")), |b| {
b.iter(|| T::decode_simd(black_box(hex_bytes)))
});
}
c.bench_function(name!("micro", "hex_byte"), |b| { c.bench_function(name!("micro", "hex_byte"), |b| {
b.iter(|| { b.iter(|| {
for b in black_box(HEX_BYTES_VALID) { for b in black_box(HEX_BYTES_VALID) {
hex_byte(b[0], b[1]); black_box(hex_byte(b[0], b[1]));
} }
}) })
}); });
c.bench_function(name!("micro", "hex_byte_niched"), |b| { bench_decoder::<HexByteDecoderA>(c, stringify!(HexByteDecoderA));
b.iter(|| { bench_decoder::<HexByteDecoderB>(c, stringify!(HexByteDecoderB));
for b in black_box(HEX_BYTES_VALID) {
hex_byte_niched(b[0], b[1]);
} }
})
});
c.bench_function(name!("micro", "hex_byte_simd"), |b| { pub fn bench_nano_hex_digit(c: &mut Criterion) {
b.iter(|| hex_byte_simd(hex_bytes)) let digit = black_box('5' as u8);
c.bench_function(name!("nano", "hex_digit"), |b| b.iter(|| hex_digit(digit)));
c.bench_function(name!("nano", "hex_digit +bb"), |b| {
b.iter(|| hex_digit(black_box(digit)))
}); });
} }
pub fn bench_nano_hex_byte(c: &mut Criterion) { pub fn bench_nano_hex_byte(c: &mut Criterion) {
let digit = black_box(['5' as u8, 'b' as u8]); const DIGITS: [u8; 2] = ['5' as u8, 'b' as u8];
let digit = black_box(DIGITS);
c.bench_function(name!("nano", "hex_byte"), |b| { c.bench_function(name!("nano", "hex_byte"), |b| {
b.iter(|| hex_byte(digit[0], digit[1])) b.iter(|| hex_byte(digit[0], digit[1]))
}); });
c.bench_function(name!("nano", "hex_byte_niched"), |b| { fn bench_decoder<T: HexByteDecoder + HexByteSimdDecoder>(c: &mut Criterion, name: &str) {
b.iter(|| hex_byte_niched(digit[0], digit[1])) let digit = black_box(DIGITS);
c.bench_function(name!("nano", format!("{name}::decode_packed")), |b| {
b.iter(|| {
black_box(T::decode_packed(&digit));
})
}); });
c.bench_function(name!("nano", "hex_byte +bb"), |b| { c.bench_function(name!("nano", format!("{name}::decode_unpacked")), |b| {
b.iter(|| hex_byte(black_box(digit[0]), black_box(digit[1]))) b.iter(|| {
}); black_box(T::decode_unpacked(
black_box(DIGITS[0]),
c.bench_function(name!("nano", "hex_byte_niched +bb"), |b| { black_box(DIGITS[1]),
b.iter(|| hex_byte_niched(black_box(digit[0]), black_box(digit[1]))) ));
})
}); });
} }
bench_decoder::<HexByteDecoderA>(c, stringify!(HexByteDecoderA));
bench_decoder::<HexByteDecoderB>(c, stringify!(HexByteDecoderB));
}
criterion_group!(decode_benches, bench_16, bench_256, bench_1_6m); criterion_group!(decode_benches, bench_16, bench_256, bench_1_6m);
criterion_group!(micro_benches, bench_micro_hex_byte); criterion_group!(micro_benches, bench_micro_hex_digit, bench_micro_hex_byte);
criterion_group!(nano_benches, bench_nano_hex_byte); criterion_group!(nano_benches, bench_nano_hex_digit, bench_nano_hex_byte);
criterion_main!(decode_benches, micro_benches, nano_benches); criterion_main!(decode_benches, micro_benches, nano_benches);

View File

@ -1,5 +1,6 @@
#![feature(array_chunks)] #![feature(array_chunks)]
#![feature(const_slice_index)] #![feature(const_slice_index)]
#![feature(const_trait_impl)]
#![feature(extend_one)] #![feature(extend_one)]
#![feature(generic_const_exprs)] #![feature(generic_const_exprs)]
#![feature(int_log)] #![feature(int_log)]
@ -9,19 +10,26 @@
#![feature(const_maybe_uninit_array_assume_init)] #![feature(const_maybe_uninit_array_assume_init)]
#![feature(const_maybe_uninit_uninit_array)] #![feature(const_maybe_uninit_uninit_array)]
#![feature(new_uninit)] #![feature(new_uninit)]
#![feature(portable_simd)] #![feature(portable_simd)]
pub(crate) mod util;
pub(crate) mod simd;
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::simd::*; use std::simd::*;
// use the maximum batch size that would be supported by AVX-512
//pub const SIMD_WIDTH: usize = 512;
pub const SIMD_WIDTH: usize = 256;
/// The batch size used for the "wide" decoded hex bytes (any bit in the upper half indicates an error). /// The batch size used for the "wide" decoded hex bytes (any bit in the upper half indicates an error).
pub const WIDE_BATCH_SIZE: usize = 512 / 16; pub const WIDE_BATCH_SIZE: usize = SIMD_WIDTH / 16;
/// The batch size used for the hex digits. /// The batch size used for the hex digits.
// use the maximum batch size that would be supported by AVX-512
pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2; pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2;
#[inline]
const fn alternating_mask<const N: usize>(first_bias: bool) -> [bool; N] { const fn alternating_mask<const N: usize>(first_bias: bool) -> [bool; N] {
let mut mask = [false; N]; let mut mask = [false; N];
let mut i = 0; let mut i = 0;
@ -39,6 +47,7 @@ const fn alternating_mask<const N: usize>(first_bias: bool) -> [bool; N] {
mask mask
} }
#[inline]
const fn msb_lsb_indices<const N: usize>() -> [usize; N] { const fn msb_lsb_indices<const N: usize>() -> [usize; N] {
if N % 2 != 0 { if N % 2 != 0 {
panic!("Illegal N"); panic!("Illegal N");
@ -55,6 +64,7 @@ const fn msb_lsb_indices<const N: usize>() -> [usize; N] {
indices indices
} }
#[inline]
const fn alternating_indices<const N: usize>(first_bias: bool) -> [usize; N] { const fn alternating_indices<const N: usize>(first_bias: bool) -> [usize; N] {
let mut indices = [0; N]; let mut indices = [0; N];
let mut i = 0; let mut i = 0;
@ -75,6 +85,10 @@ const fn alternating_indices<const N: usize>(first_bias: bool) -> [usize; N] {
const MSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(true); const MSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(true);
const LSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(false); const LSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(false);
pub const INVALID_BIT: u8 = 0b1000_0000;
pub const WIDE_INVALID_BIT: u16 = 0b1000_1000_0000_0000;
const ASCII_DIGITS: [u8; 256] = { const ASCII_DIGITS: [u8; 256] = {
let mut digits = [0u8; 256]; let mut digits = [0u8; 256];
let mut i = u8::MIN; let mut i = u8::MIN;
@ -90,7 +104,7 @@ const ASCII_DIGITS: [u8; 256] = {
DIGIT_MIN..=DIGIT_MAX => i - DIGIT_MIN, DIGIT_MIN..=DIGIT_MAX => i - DIGIT_MIN,
LOWER_MIN..=LOWER_MAX => 10 + i - LOWER_MIN, LOWER_MIN..=LOWER_MAX => 10 + i - LOWER_MIN,
UPPER_MIN..=UPPER_MAX => 10 + i - UPPER_MIN, UPPER_MIN..=UPPER_MAX => 10 + i - UPPER_MIN,
_ => 255, _ => INVALID_BIT,
}; };
i += 1; i += 1;
@ -98,7 +112,7 @@ const ASCII_DIGITS: [u8; 256] = {
digits digits
}; };
/// Returns 255 if invalid. Based on `char.to_digit()` in the stdlib. /// Returns [`INVALID_BIT`] if invalid. Based on `char.to_digit()` in the stdlib.
#[inline] #[inline]
pub const fn hex_digit(ascii: u8) -> u8 { pub const fn hex_digit(ascii: u8) -> u8 {
// use std::ops::RangeInclusive; // use std::ops::RangeInclusive;
@ -112,7 +126,7 @@ pub const fn hex_digit(ascii: u8) -> u8 {
// DIGIT_MIN..=DIGIT_MAX => ascii - DIGIT_MIN, // DIGIT_MIN..=DIGIT_MAX => ascii - DIGIT_MIN,
// LOWER_MIN..=LOWER_MAX => 10 + ascii - LOWER_MIN, // LOWER_MIN..=LOWER_MAX => 10 + ascii - LOWER_MIN,
// UPPER_MIN..=UPPER_MAX => 10 + ascii - UPPER_MIN, // UPPER_MIN..=UPPER_MAX => 10 + ascii - UPPER_MIN,
// _ => 255, // _ => INVALID_BIT,
// } // }
ASCII_DIGITS[ascii as usize] ASCII_DIGITS[ascii as usize]
// let mut digit = ascii.wrapping_sub('0' as u8); // let mut digit = ascii.wrapping_sub('0' as u8);
@ -124,69 +138,162 @@ pub const fn hex_digit(ascii: u8) -> u8 {
// if digit < 6 { // if digit < 6 {
// return digit + 10; // return digit + 10;
// } // }
// return 255; // return INVALID_BIT;
} }
#[inline(always)] #[inline(always)]
pub fn hex_digits<const LANES: usize>(ascii: Simd<u8, LANES>) -> Simd<u8, LANES> where LaneCount<LANES>: SupportedLaneCount { pub fn hex_digit_simd<const LANES: usize>(ascii: Simd<u8, LANES>) -> Simd<u8, LANES>
unsafe { Simd::gather_select_unchecked(&ASCII_DIGITS, Mask::splat(true), ascii.cast(), Simd::splat(0)) } where
LaneCount<LANES>: SupportedLaneCount,
{
unsafe {
Simd::gather_select_unchecked(
&ASCII_DIGITS,
Mask::splat(true),
ascii.cast(),
simd::splat_0::<LANES>(),
)
}
} }
/// Parses an ascii hex byte. /// Parses an ascii hex byte.
#[inline] #[inline(always)]
pub const fn hex_byte(msb: u8, lsb: u8) -> Option<u8> { pub const fn hex_byte(msb: u8, lsb: u8) -> Option<u8> {
let msb = hex_digit(msb); let msb = hex_digit(msb);
let lsb = hex_digit(lsb); let lsb = hex_digit(lsb);
if (msb | lsb) == 255 { // second is faster (perhaps it pipelines better?)
//if (msb | lsb) & INVALID_BIT != 0 {
if (msb & INVALID_BIT) | (lsb & INVALID_BIT) != 0 {
return None; return None;
} }
Some(msb << 4 | lsb) Some(msb << 4 | lsb)
} }
/// Parses an ascii hex byte. Any value > [`u8::MAX`] is invalid. /// A decoder for a single hex byte.
#[inline] #[const_trait]
pub const fn hex_byte_niched(msb: u8, lsb: u8) -> u16 { pub trait HexByteDecoder {
let msb = hex_digit(msb) as u16; /// Parses an ascii hex byte. Any return value exceeding [`u8::MAX`] indicates invalid input.
let lsb = hex_digit(lsb) as u16; fn decode_unpacked(hi: u8, lo: u8) -> u16;
(msb << 4) | (lsb & 0xf) | ((lsb & 0xf0) << 8)
/// Parses an ascii hex byte. Any return value exceeding [`u8::MAX`] indicates invalid input.
#[inline(always)]
fn decode_packed([hi, lo]: &[u8; 2]) -> u16 {
Self::decode_unpacked(*hi, *lo)
}
} }
/// Parses an ascii hex byte. /// A decoder for a sized batch of hex bytes.
#[inline] pub trait HexByteSimdDecoder {
pub const fn hex_byte_packed([msb, lsb]: &[u8; 2]) -> Option<u8> { /// Parses an ascii hex byte. Any element of the return value exceeding [`u8::MAX`] indicates invalid input.
let msb = hex_digit(*msb); fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>>;
let lsb = hex_digit(*lsb); }
if (msb | lsb) == 255 {
/// Scalar hex-byte decoder that masks the low digit before merging it into the result.
///
/// Both methods look each ASCII digit up with [`hex_digit`] and widen it to `u16`.
/// A valid pair produces a value `<= u8::MAX`. An invalid digit (for which
/// [`hex_digit`] yields [`INVALID_BIT`]) ends up in the upper byte of the result:
/// an invalid `hi` sets bit 11 (`INVALID_BIT << 4`) and an invalid `lo` sets
/// bit 15 (`(lo & 0xf0) << 8`). These are the two bits of [`WIDE_INVALID_BIT`].
pub struct HexByteDecoderA;
impl const HexByteDecoder for HexByteDecoderA {
/// Decodes the digit pair `hi`, `lo` into a widened byte.
///
/// Returns the decoded byte in the low 8 bits, or a value above [`u8::MAX`]
/// if either digit is not a hex digit.
#[inline(always)]
fn decode_unpacked(hi: u8, lo: u8) -> u16 {
let hi = hex_digit(hi) as u16;
let lo = hex_digit(lo) as u16;
// might these masks allow the ORs to be pipelined more efficiently?
// `lo & 0xf` keeps only the nibble; `(lo & 0xf0) << 8` moves the invalid
// marker of `lo` into the upper byte.
(hi << 4) | (lo & 0xf) | ((lo & 0xf0) << 8)
}
/// Same as [`Self::decode_unpacked`], but takes the two digits as a reference
/// to a `[hi, lo]` array. The body is written out rather than delegating.
#[inline(always)]
fn decode_packed([hi, lo]: &[u8; 2]) -> u16 {
let hi = hex_digit(*hi) as u16;
let lo = hex_digit(*lo) as u16;
(hi << 4) | (lo & 0xf) | ((lo & 0xf0) << 8)
}
}
impl HexByteSimdDecoder for HexByteDecoderA {
#[inline(always)]
fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>> {
let hex_digits = hex_digit_simd::<DIGIT_BATCH_SIZE>(Simd::from_array(hi_los));
if ((hex_digits & simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT))
.simd_ne(simd::splat_0::<DIGIT_BATCH_SIZE>()))
.any()
{
return None; return None;
} }
Some(msb << 4 | lsb)
}
/// Parses an ascii hex byte. Any value > [`u8::MAX`] is invalid.
#[inline]
pub const fn hex_byte_packed_niched([msb, lsb]: &[u8; 2]) -> u16 {
let msb = hex_digit(*msb) as u16;
let lsb = hex_digit(*lsb) as u16;
(msb << 4) | (lsb & 0xf) | ((lsb & 0xf0) << 8)
}
#[inline]
pub fn hex_byte_simd(hex_bytes: Simd<u8, DIGIT_BATCH_SIZE>) -> Simd<u16, WIDE_BATCH_SIZE> {
let hex_digits = hex_digits(hex_bytes);
//println!("hex_digits: {hex_digits:04x?}");
//println!("MSB_INDICES: {MSB_INDICES:02?}");
//println!("LSB_INDICES: {LSB_INDICES:02?}");
let msb = simd_swizzle!(hex_digits, MSB_INDICES); let msb = simd_swizzle!(hex_digits, MSB_INDICES);
let lsb = simd_swizzle!(hex_digits, LSB_INDICES); let lsb = simd_swizzle!(hex_digits, LSB_INDICES);
//println!("msb: {msb:04x?}"); /*let msb = msb.cast::<u16>();
//println!("lsb: {lsb:04x?}");
let msb = msb.cast::<u16>();
let lsb = lsb.cast::<u16>(); let lsb = lsb.cast::<u16>();
msb << Simd::splat(4) | lsb | ((lsb & Simd::splat(0xf0)) << Simd::splat(8)) let buf = msb << simd::splat_n::<WIDE_BATCH_SIZE>(4) | lsb | ((lsb & simd::splat_n::<WIDE_BATCH_SIZE>(0xf0)) << simd::splat_n::<WIDE_BATCH_SIZE>(8));
if buf.simd_gt(simd::splat_n::<WIDE_BATCH_SIZE>(u8::MAX as u16)).any() {
return None;
}
Some(buf.cast::<u8>())*/
Some((msb << simd::splat_n::<WIDE_BATCH_SIZE>(4)) | lsb)
}
} }
/// Hex-byte decoder variant that merges the low digit unmasked and does its
/// batch digit lookups with scalar table reads.
///
/// As with the scalar methods below, a valid pair produces a value `<= u8::MAX`.
/// An invalid `hi` sets bit 11 (`INVALID_BIT << 4`); an invalid `lo` sets bit 15
/// (`(lo & INVALID_BIT) << 8`) and also leaves [`INVALID_BIT`] set in the low byte.
pub struct HexByteDecoderB;
impl const HexByteDecoder for HexByteDecoderB {
// NOTE(review): `util::defer_impl!` is defined outside this view. With every
// method signature commented out it presumably forwards nothing to
// `HexByteDecoderA`, leaving the methods below as the effective ones — confirm.
util::defer_impl! {
=> HexByteDecoderA;
//fn decode_unpacked(hi: u8, lo: u8) -> u16;
//fn decode_packed(hi_lo: &[u8; 2]) -> u16;
}
/// Decodes the digit pair `hi`, `lo` into a widened byte.
///
/// Returns a value above [`u8::MAX`] if either digit is not a hex digit.
#[inline(always)]
fn decode_unpacked(hi: u8, lo: u8) -> u16 {
// `lo` is deliberately looked up before `hi`; see the note below.
let lo = hex_digit(lo) as u16;
let hi = hex_digit(hi) as u16;
// kind of bizarre: changing the order of these decreases perf by 6-12%
(hi << 4) | lo | ((lo & INVALID_BIT as u16) << 8)
}
/// Same as [`Self::decode_unpacked`], but takes the two digits as a reference
/// to a `[hi, lo]` array. The body is written out rather than delegating.
#[inline(always)]
fn decode_packed([hi, lo]: &[u8; 2]) -> u16 {
let lo = hex_digit(*lo) as u16;
let hi = hex_digit(*hi) as u16;
(hi << 4) | lo | ((lo & INVALID_BIT as u16) << 8)
}
}
impl HexByteSimdDecoder for HexByteDecoderB {
// NOTE(review): as above, the only forwarded signature is commented out, so
// this invocation presumably expands to nothing — confirm against `util`.
util::defer_impl! {
=> HexByteDecoderA;
//fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>>;
}
/// Decodes a batch of `DIGIT_BATCH_SIZE` ASCII hex digits into
/// `WIDE_BATCH_SIZE` bytes.
///
/// Returns `None` if any digit in the batch has [`INVALID_BIT`] set after
/// lookup, i.e. is not a hex digit.
#[inline(always)]
fn decode_simd(mut hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>> {
// Replace every ASCII digit in place with its table value, one lookup at a time.
for b in hi_los.iter_mut() {
*b = hex_digit(*b);
}
let hex_digits = Simd::from_array(hi_los);
// Reject the whole batch if any lane carries the invalid marker.
if (hex_digits & simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT))
.simd_ne(simd::splat_0::<DIGIT_BATCH_SIZE>())
.any()
{
//if hex_digits.simd_eq(simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT)).any() {
return None;
}
// Split the digits into the high and low halves of each byte.
// NOTE(review): presumably MSB_INDICES/LSB_INDICES pick the even/odd lanes;
// the body of `alternating_indices` is not fully visible here — confirm.
let msb = simd_swizzle!(hex_digits, MSB_INDICES);
let lsb = simd_swizzle!(hex_digits, LSB_INDICES);
let mut v = Simd::from_array([0u8; WIDE_BATCH_SIZE]);
for (i, v) in v.as_mut_array().iter_mut().enumerate() {
// SAFETY: `i` indexes an array of WIDE_BATCH_SIZE elements, and `msb`/`lsb`
// have DIGIT_BATCH_SIZE / 2 == WIDE_BATCH_SIZE lanes, so `i` is in bounds.
let hi = unsafe { *msb.as_array().get_unchecked(i) };
let lo = unsafe { *lsb.as_array().get_unchecked(i) };
// Every digit passed the validity check above, so this is a plain nibble merge.
*v = (hi << 4) | lo;
}
Some(v)
//msb << simd::splat_n::<WIDE_BATCH_SIZE>(4) | lsb | ((lsb & simd::splat_n::<WIDE_BATCH_SIZE>(0xf0)) << simd::splat_n::<WIDE_BATCH_SIZE>(8))
}
}
/// The decoder implementation the decoding functions in this file call through.
pub type HBD = HexByteDecoderB;
pub mod conv { pub mod conv {
use std::simd::{Simd, LaneCount, SupportedLaneCount}; use std::simd::{LaneCount, Simd, SupportedLaneCount};
/*trait Size { /*trait Size {
const N: usize; const N: usize;
@ -241,19 +348,29 @@ pub mod conv {
size_of_impl!(u64 = SizeU64);*/ size_of_impl!(u64 = SizeU64);*/
#[allow(non_camel_case_types, non_snake_case)] #[allow(non_camel_case_types, non_snake_case)]
union u8_u16<const N_u16: usize> where [u8; N_u16 * 2]: { union u8_u16<const N_u16: usize>
where
[u8; N_u16 * 2]:,
{
u8: [u8; N_u16 * 2], u8: [u8; N_u16 * 2],
u16: [u16; N_u16], u16: [u16; N_u16],
} }
#[allow(non_camel_case_types, non_snake_case)] #[allow(non_camel_case_types, non_snake_case)]
union u8x2_u8<const N_u8x2: usize> where [u8; N_u8x2 * 2]: { union u8x2_u8<const N_u8x2: usize>
where
[u8; N_u8x2 * 2]:,
{
u8x2: [[u8; 2]; N_u8x2], u8x2: [[u8; 2]; N_u8x2],
u8: [u8; N_u8x2 * 2], u8: [u8; N_u8x2 * 2],
} }
#[allow(non_camel_case_types, non_snake_case)] #[allow(non_camel_case_types, non_snake_case)]
union SimdU8_SimdU16<const N_U16: usize> where LaneCount<{ N_U16 * 2 }>: SupportedLaneCount, LaneCount<N_U16>: SupportedLaneCount { union SimdU8_SimdU16<const N_U16: usize>
where
LaneCount<{ N_U16 * 2 }>: SupportedLaneCount,
LaneCount<N_U16>: SupportedLaneCount,
{
SimdU8: Simd<u8, { N_U16 * 2 }>, SimdU8: Simd<u8, { N_U16 * 2 }>,
SimdU16: Simd<u16, N_U16>, SimdU16: Simd<u16, N_U16>,
} }
@ -269,10 +386,12 @@ pub mod conv {
} }
#[inline(always)] #[inline(always)]
pub const fn simdu8_to_simdu16<const N_OUT: usize>(a: Simd<u8, { N_OUT * 2 }>) -> Simd<u16, N_OUT> pub const fn simdu8_to_simdu16<const N_OUT: usize>(
a: Simd<u8, { N_OUT * 2 }>,
) -> Simd<u16, N_OUT>
where where
LaneCount<{ N_OUT * 2 }>: SupportedLaneCount, LaneCount<{ N_OUT * 2 }>: SupportedLaneCount,
LaneCount<N_OUT>: SupportedLaneCount LaneCount<N_OUT>: SupportedLaneCount,
{ {
unsafe { SimdU8_SimdU16 { SimdU8: a }.SimdU16 } unsafe { SimdU8_SimdU16 { SimdU8: a }.SimdU16 }
} }
@ -296,16 +415,15 @@ const fn align_up_to<const N: usize>(n: usize) -> usize {
return (n + (N - 1)) >> shift << shift; return (n + (N - 1)) >> shift << shift;
} }
macro_rules! decode_hex_bytes_non_vectored { macro_rules! decode_hex_bytes_non_vectored {
($i:ident, $ascii:ident, $bytes:ident, $o:expr) => {{ ($i:ident, $ascii:ident, $bytes:ident, $o:expr) => {{
while $i < $ascii.len() { while $i < $ascii.len() {
match unsafe { hex_byte(*$ascii.get_unchecked($i), *$ascii.get_unchecked($i + 1)) } { match unsafe { hex_byte(*$ascii.get_unchecked($i), *$ascii.get_unchecked($i + 1)) } {
Some(b) => unsafe { *$bytes.get_unchecked_mut($o + ($i >> 1)) = MaybeUninit::new(b) }, Some(b) => unsafe { *$bytes.get_unchecked_mut($o + ($i >> 1)) = MaybeUninit::new(b) },
None => { None => {
println!("bad hex byte at {} ({}{})", $i, $ascii[$i] as char, $ascii[$i + 1] as char); //println!("bad hex byte at {} ({}{})", $i, $ascii[$i] as char, $ascii[$i + 1] as char);
return false return false
}, }
} }
$i += 2; $i += 2;
} }
@ -314,29 +432,34 @@ macro_rules! decode_hex_bytes_non_vectored {
#[inline(always)] #[inline(always)]
fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bool { fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bool {
debug_assert_eq!(ascii.len() >> 1 << 1, ascii.len(), "len of ascii is not a multiple of 2"); debug_assert_eq!(
debug_assert_eq!(ascii.len() >> 1, bytes.len(), "len of ascii is not twice that of bytes"); ascii.len() >> 1 << 1,
const VECTORED: bool = true; ascii.len(),
if VECTORED { "len of ascii is not a multiple of 2"
union Aligned { );
bytes: [u8; DIGIT_BATCH_SIZE], debug_assert_eq!(
simd: Simd<u8, DIGIT_BATCH_SIZE>, ascii.len() >> 1,
} bytes.len(),
"len of ascii is not twice that of bytes"
);
const VECTORED_A: bool = false;
const VECTORED_B: bool = false;
const VECTORED_C: bool = false;
if VECTORED_A {
let mut i = 0; let mut i = 0;
while i < align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) { while i < align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) {
let slice = unsafe { &*(ascii.as_ptr().add(i) as *const [u8; DIGIT_BATCH_SIZE]) }; let buf = HBD::decode_simd(unsafe {
let aligned = Aligned { bytes: *slice }; *(ascii.as_ptr().add(i) as *const [u8; DIGIT_BATCH_SIZE])
let buf = hex_byte_simd(unsafe { aligned.simd }); });
if buf > Simd::splat(u8::MAX as u16) { let buf = match buf {
println!("bad hex byte at {}..{}", i, i + DIGIT_BATCH_SIZE); Some(buf) => buf,
return false; None => return false,
} };
let buf = buf.cast::<u8>();
let mut j = 0; let mut j = 0;
while j < DIGIT_BATCH_SIZE { while j < DIGIT_BATCH_SIZE {
unsafe { unsafe {
*bytes.get_unchecked_mut((i >> 1) + j) = MaybeUninit::new(*buf.as_array().get_unchecked(j)) *bytes.get_unchecked_mut((i >> 1) + j) =
MaybeUninit::new(*buf.as_array().get_unchecked(j))
}; };
j += 1; j += 1;
} }
@ -344,28 +467,30 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
} }
decode_hex_bytes_non_vectored!(i, ascii, bytes, 0); decode_hex_bytes_non_vectored!(i, ascii, bytes, 0);
} else if false { } else if VECTORED_B {
let (ascii_pre, ascii_simd, ascii_post) = unsafe { ascii.align_to::<Simd<u8, DIGIT_BATCH_SIZE>>() }; let (ascii_pre, ascii_simd, ascii_post) =
unsafe { ascii.align_to::<Simd<u8, DIGIT_BATCH_SIZE>>() };
assert_eq!(ascii_pre.len() % 2, 0); debug_assert_eq!(ascii_pre.len() % 2, 0);
assert_eq!(ascii_post.len() % 2, 0); debug_assert_eq!(ascii_post.len() % 2, 0);
let mut i = 0; let mut i = 0;
decode_hex_bytes_non_vectored!(i, ascii_pre, bytes, 0); decode_hex_bytes_non_vectored!(i, ascii_pre, bytes, 0);
let mut i = 0; let mut i = 0;
while i < ascii_simd.len() { while i < ascii_simd.len() {
let buf = hex_byte_simd(unsafe { *ascii_simd.get_unchecked(i) }); // this to_array and any subsequent from_array should be eliminated anyway
if buf > Simd::splat(u8::MAX as u16) { let buf = HBD::decode_simd(unsafe { ascii_simd.get_unchecked(i) }.to_array());
println!("bad hex byte at {}..{}", i, i + DIGIT_BATCH_SIZE); let buf = match buf {
return false; Some(buf) => buf,
} None => return false,
let buf = buf.cast::<u8>(); };
let mut j = 0; let mut j = 0;
let k = ascii_pre.len() + i * DIGIT_BATCH_SIZE; let k = ascii_pre.len() + i * DIGIT_BATCH_SIZE;
while j < DIGIT_BATCH_SIZE { while j < DIGIT_BATCH_SIZE {
unsafe { unsafe {
*bytes.get_unchecked_mut(k + j) = MaybeUninit::new(*buf.as_array().get_unchecked(j)) *bytes.get_unchecked_mut(k + j) =
MaybeUninit::new(*buf.as_array().get_unchecked(j))
}; };
j += 1; j += 1;
} }
@ -375,6 +500,28 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
let mut i = 0; let mut i = 0;
let k = ascii.len() - ascii_post.len(); let k = ascii.len() - ascii_post.len();
decode_hex_bytes_non_vectored!(i, ascii_post, bytes, k); decode_hex_bytes_non_vectored!(i, ascii_post, bytes, k);
} else if VECTORED_C {
let mut i = 0;
while i < align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) {
let buf = HBD::decode_simd(unsafe {
*(ascii.as_ptr().add(i) as *const [u8; DIGIT_BATCH_SIZE])
});
let buf = match buf {
Some(buf) => buf,
None => return false,
};
let mut j = 0;
while j < DIGIT_BATCH_SIZE {
unsafe {
*bytes.get_unchecked_mut((i >> 1) + j) =
MaybeUninit::new(*buf.as_array().get_unchecked(j))
};
j += 1;
}
i += DIGIT_BATCH_SIZE;
}
decode_hex_bytes_non_vectored!(i, ascii, bytes, 0);
} else { } else {
let mut i = 0; let mut i = 0;
decode_hex_bytes_non_vectored!(i, ascii, bytes, 0); decode_hex_bytes_non_vectored!(i, ascii, bytes, 0);
@ -391,14 +538,17 @@ pub const fn hex_bytes_sized_const<const N: usize>(ascii: &[u8; N * 2]) -> Optio
} else { } else {
let mut bytes = MaybeUninit::<u8>::uninit_array::<N>(); let mut bytes = MaybeUninit::<u8>::uninit_array::<N>();
let mut i = 0; let mut i = 0;
while i < bytes.len() { while i < N * 2 {
match hex_byte(unsafe { *ascii.get_unchecked(i << 1) }, unsafe { if i >> 1 >= bytes.len() {
*ascii.get_unchecked((i << 1) + 1) unsafe { std::hint::unreachable_unchecked() };
}
match hex_byte(unsafe { *ascii.get_unchecked(i) }, unsafe {
*ascii.get_unchecked(i + 1)
}) { }) {
Some(b) => bytes[i] = MaybeUninit::new(b), Some(b) => bytes[i >> 1] = MaybeUninit::new(b),
None => return None, None => return None,
} }
i += 1; i += 2;
} }
Some(unsafe { MaybeUninit::array_assume_init(bytes) }) Some(unsafe { MaybeUninit::array_assume_init(bytes) })
} }
@ -418,12 +568,12 @@ pub fn hex_bytes_sized<const N: usize>(ascii: &[u8; N * 2]) -> Option<[u8; N]> {
} }
} }
#[inline]
pub fn hex_bytes_sized_heap<const N: usize>(ascii: &[u8; N * 2]) -> Option<Box<[u8; N]>> { pub fn hex_bytes_sized_heap<const N: usize>(ascii: &[u8; N * 2]) -> Option<Box<[u8; N]>> {
if N == 0 { if N == 0 {
Some(Box::new([0u8; N])) Some(Box::new([0u8; N]))
} else { } else {
let mut bytes: Box<[MaybeUninit<u8>; N]> = let mut bytes = unsafe { Box::<[MaybeUninit<u8>; N]>::new_uninit().assume_init() };
unsafe { Box::<[MaybeUninit<u8>; N]>::new_uninit().assume_init() };
if decode_hex_bytes_unchecked(ascii, bytes.as_mut()) { if decode_hex_bytes_unchecked(ascii, bytes.as_mut()) {
Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8; N]) }) Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8; N]) })
} else { } else {
@ -432,19 +582,21 @@ pub fn hex_bytes_sized_heap<const N: usize>(ascii: &[u8; N * 2]) -> Option<Box<[
} }
} }
#[inline]
pub fn hex_bytes_dyn_unsafe(ascii: &[u8]) -> Option<Box<[u8]>> { pub fn hex_bytes_dyn_unsafe(ascii: &[u8]) -> Option<Box<[u8]>> {
let len = ascii.len() >> 1; let len = ascii.len() >> 1;
if len << 1 != ascii.len() { if len << 1 != ascii.len() {
return None; return None;
} }
let mut bytes = Box::<[u8]>::new_uninit_slice(len); let mut bytes = Box::new_uninit_slice(len);
if decode_hex_bytes_unchecked(ascii, bytes.as_mut()) { if decode_hex_bytes_unchecked(ascii, bytes.as_mut()) {
Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8]) }) Some(unsafe { Box::<[_]>::assume_init(bytes) })
} else { } else {
None None
} }
} }
#[inline]
pub fn hex_bytes_dyn_unsafe_iter(ascii: &[u8]) -> Option<Box<[u8]>> { pub fn hex_bytes_dyn_unsafe_iter(ascii: &[u8]) -> Option<Box<[u8]>> {
let len = ascii.len() >> 1; let len = ascii.len() >> 1;
if len << 1 != ascii.len() { if len << 1 != ascii.len() {
@ -462,9 +614,10 @@ pub fn hex_bytes_dyn_unsafe_iter(ascii: &[u8]) -> Option<Box<[u8]>> {
return None; return None;
} }
} }
Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8]) }) Some(unsafe { Box::<[_]>::assume_init(bytes) })
} }
#[inline]
pub fn hex_bytes_dyn_unsafe_iter_niched(ascii: &[u8]) -> Option<Box<[u8]>> { pub fn hex_bytes_dyn_unsafe_iter_niched(ascii: &[u8]) -> Option<Box<[u8]>> {
let len = ascii.len() >> 1; let len = ascii.len() >> 1;
if len << 1 != ascii.len() { if len << 1 != ascii.len() {
@ -473,18 +626,18 @@ pub fn hex_bytes_dyn_unsafe_iter_niched(ascii: &[u8]) -> Option<Box<[u8]>> {
let mut bytes = Box::<[u8]>::new_uninit_slice(len); let mut bytes = Box::<[u8]>::new_uninit_slice(len);
for (i, b) in ascii for (i, b) in ascii
.array_chunks::<2>() .array_chunks::<2>()
.map(|[msb, lsb]| hex_byte_niched(*msb, *lsb)) .map(HBD::decode_packed)
.enumerate() .enumerate()
{ {
if b & 0xff_00 == 0 { if b & WIDE_INVALID_BIT != 0 {
unsafe { *bytes.get_unchecked_mut(i) = MaybeUninit::new(b as u8) };
} else {
return None; return None;
} }
unsafe { *bytes.get_unchecked_mut(i) = MaybeUninit::new(b as u8) };
} }
Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8]) }) Some(unsafe { Box::<[_]>::assume_init(bytes) })
} }
#[inline]
pub fn hex_bytes_dyn(ascii: &[u8]) -> Option<Box<[u8]>> { pub fn hex_bytes_dyn(ascii: &[u8]) -> Option<Box<[u8]>> {
let iter = ascii.array_chunks::<2>(); let iter = ascii.array_chunks::<2>();
if iter.remainder().len() != 0 { if iter.remainder().len() != 0 {
@ -495,13 +648,20 @@ pub fn hex_bytes_dyn(ascii: &[u8]) -> Option<Box<[u8]>> {
.map(|v| v.into_boxed_slice()) .map(|v| v.into_boxed_slice())
} }
#[inline]
pub fn hex_bytes_dyn_niched(ascii: &[u8]) -> Option<Box<[u8]>> { pub fn hex_bytes_dyn_niched(ascii: &[u8]) -> Option<Box<[u8]>> {
let iter = ascii.array_chunks::<2>(); let iter = ascii.array_chunks::<2>();
if iter.remainder().len() != 0 { if iter.remainder().len() != 0 {
return None; return None;
} }
iter.map(|[msb, lsb]| hex_byte_niched(*msb, *lsb)) iter.map(HBD::decode_packed)
.map(|b| if b & 0xff_00 == 0 { Some(b as u8) } else { None }) .map(|b| {
if b & WIDE_INVALID_BIT != 0 {
None
} else {
Some(b as u8)
}
})
.collect::<Option<Vec<u8>>>() .collect::<Option<Vec<u8>>>()
.map(|v| v.into_boxed_slice()) .map(|v| v.into_boxed_slice())
} }
@ -522,8 +682,21 @@ mod test {
} }
const SAMPLES: &[Sample] = &[ const SAMPLES: &[Sample] = &[
Sample { bytes: BYTES, hex_bytes: HEX_BYTES }, Sample {
Sample { bytes: LONG_BYTES, hex_bytes: LONG_HEX_BYTES }, bytes: BYTES,
hex_bytes: HEX_BYTES,
},
Sample {
bytes: LONG_BYTES,
hex_bytes: LONG_HEX_BYTES,
},
];
const INVALID_SAMPLES: &[&[u8]] = &[
b"446F6C6F72756D2064697374696E6374696F20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722G",
b"446F6C6F72756D2064697374696E6374696F20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E7365637465747572GE",
b"446F6C6F72756D2064697374696E6374696G20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E",
b"446F6C6F72756D2064697374696E637469GF20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E",
]; ];
#[test] #[test]
@ -547,8 +720,10 @@ mod test {
#[test] #[test]
fn test_hex_digit_simd() { fn test_hex_digit_simd() {
const HEX_DIGITS: &[char; DIGIT_BATCH_SIZE] = &[ const HEX_DIGITS: &[char; DIGIT_BATCH_SIZE] = &[
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E',
'F',
// '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
]; ];
let mut set8 = [0u8; DIGIT_BATCH_SIZE]; let mut set8 = [0u8; DIGIT_BATCH_SIZE];
@ -559,7 +734,10 @@ mod test {
for i in 0..(DIGIT_BATCH_SIZE) { for i in 0..(DIGIT_BATCH_SIZE) {
sete[i] = (i as u8) % 16; sete[i] = (i as u8) % 16;
} }
assert_eq!(hex_digits(Simd::from_array(set8)), Simd::from_array(sete)); assert_eq!(
hex_digit_simd::<DIGIT_BATCH_SIZE>(Simd::from_array(set8)),
Simd::from_array(sete)
);
} }
#[test] #[test]
@ -576,7 +754,12 @@ mod test {
for (hb, b) in HEX_BYTES_VALID { for (hb, b) in HEX_BYTES_VALID {
assert_eq!(hex_byte(hb[0], hb[1]), Some(*b)); assert_eq!(hex_byte(hb[0], hb[1]), Some(*b));
assert_eq!(hex_byte_niched(hb[0], hb[1]), *b as u16);
assert_eq!(HexByteDecoderA::decode_unpacked(hb[0], hb[1]), *b as u16);
assert_eq!(HexByteDecoderB::decode_unpacked(hb[0], hb[1]), *b as u16);
assert_eq!(HexByteDecoderA::decode_packed(hb), *b as u16);
assert_eq!(HexByteDecoderB::decode_packed(hb), *b as u16);
} }
const HEX_BYTES_INVALID: &[[u8; 2]] = &[ const HEX_BYTES_INVALID: &[[u8; 2]] = &[
@ -591,39 +774,47 @@ mod test {
for hb in HEX_BYTES_INVALID { for hb in HEX_BYTES_INVALID {
assert_eq!(hex_byte(hb[0], hb[1]), None); assert_eq!(hex_byte(hb[0], hb[1]), None);
assert!(hex_byte_niched(hb[0], hb[1]) & 0xff_00 != 0); assert_ne!(
HexByteDecoderA::decode_unpacked(hb[0], hb[1]) & WIDE_INVALID_BIT,
0
);
assert_ne!(
HexByteDecoderB::decode_unpacked(hb[0], hb[1]) & WIDE_INVALID_BIT,
0
);
assert_ne!(HexByteDecoderA::decode_packed(hb) & WIDE_INVALID_BIT, 0);
assert_ne!(HexByteDecoderB::decode_packed(hb) & WIDE_INVALID_BIT, 0);
} }
} }
#[test] #[test]
fn test_hex_byte_simd() { fn test_hex_byte_simd() {
const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [ const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", *b"ff", *b"00", *b"11",
*b"ef", *b"fe", *b"0f", *b"f0",
*b"34",
// *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", // *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
]; ];
const BYTES_VALID: [u8; WIDE_BATCH_SIZE] = [ const BYTES_VALID: [u8; WIDE_BATCH_SIZE] = [
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f,
0xf0,
0x34,
// 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, // 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
]; ];
let hex_bytes = Simd::from_array(conv::u8x2_to_u8(HEX_BYTES_VALID)); let hex_bytes = conv::u8x2_to_u8(HEX_BYTES_VALID);
let bytes = BYTES_VALID.map(|b| b as u16); let bytes = Simd::from_array(BYTES_VALID);
let bytes = Simd::from_array(bytes);
println!("hex_bytes: {HEX_BYTES_VALID:02x?}"); println!("hex_bytes: {HEX_BYTES_VALID:02x?}");
println!("hex_bytes: {hex_bytes:02x?}"); println!("hex_bytes: {hex_bytes:02x?}");
println!("bytes: {BYTES_VALID:02x?}"); println!("bytes: {BYTES_VALID:02x?}");
println!("bytes: {bytes:04x?}"); println!("bytes: {bytes:04x?}");
assert_eq!(hex_byte_simd(hex_bytes), bytes); assert_eq!(HexByteDecoderA::decode_simd(hex_bytes), Some(bytes));
assert_eq!(HexByteDecoderB::decode_simd(hex_bytes), Some(bytes));
/*const HEX_BYTES_INVALID: &[[u8; 2]] = &[ /*const HEX_BYTES_INVALID: &[[u8; 2]] = &[
['f' as u8, 'g' as u8], ['f' as u8, 'g' as u8],
@ -637,14 +828,27 @@ mod test {
for hb in HEX_BYTES_INVALID { for hb in HEX_BYTES_INVALID {
assert_eq!(hex_byte(hb[0], hb[1]), None); assert_eq!(hex_byte(hb[0], hb[1]), None);
assert!(hex_byte_niched(hb[0], hb[1]) & 0xff_00 != 0); assert!(hex_byte_niched(hb[0], hb[1]) & WIDE_INVALID_BIT != 0);
}*/ }*/
} }
fn test_f(f: fn(&[u8]) -> Option<Box<[u8]>>) { fn test_f(f: fn(&[u8]) -> Option<Box<[u8]>>) {
for (i, Sample { bytes, hex_bytes }) in SAMPLES.into_iter().enumerate() { for (i, Sample { bytes, hex_bytes }) in SAMPLES.into_iter().enumerate() {
let result = f(hex_bytes); let result = f(hex_bytes);
assert_eq!(Some(*bytes), result.as_ref().map(Box::as_ref), "Sample {i} did not decode correctly"); assert_eq!(
Some(*bytes),
result.as_ref().map(Box::as_ref),
"Sample {i} did not decode correctly"
);
}
for (i, hex_bytes) in INVALID_SAMPLES.into_iter().enumerate() {
let result = f(hex_bytes);
assert_eq!(
None,
result.as_ref().map(Box::as_ref),
"Sample {i} did not decode correctly"
);
} }
} }

89
src/simd.rs Normal file
View File

@ -0,0 +1,89 @@
use std::simd::{LaneCount, Simd, SupportedLaneCount};
use crate::util::cast;
/// Produces a `Simd<u8, LANES>` through an associated function that takes no arguments.
///
/// In this file the trait is implemented for [`SimdOps`], where every arm yields an all-zero
/// vector (a `setzero` intrinsic, or `splat_n(0)` as the fallback).
pub trait SimdSplatZero<const LANES: usize> {
/// Returns the vector; only callable when `LANES` is a supported lane count.
fn splat_zero() -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount;
}
/// Produces a `Simd<u8, LANES>` from a single byte `n`.
///
/// In this file the trait is implemented for [`SimdOps`], where every arm broadcasts `n` to all
/// lanes (a `set1_epi8` intrinsic, or `Simd::splat` as the fallback).
pub trait SimdSplatN<const LANES: usize> {
/// Returns the vector built from `n`; only callable when `LANES` is a supported lane count.
fn splat_n(n: u8) -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount;
}
/// Field-less marker type that the `specialized!` macro implements the splat traits on.
pub struct SimdOps;
// Vector widths in bytes, i.e. the number of `u8` lanes that make up a 128-, 256- or 512-bit
// vector. They are used below as `match` patterns against the `LANES` const parameter.
const W_128: usize = 128 / 8;
const W_256: usize = 256 / 8;
const W_512: usize = 512 / 8;
// Generates `impl<const LANES: usize> Trait<LANES> for SimdOps`, where the body of every method
// is a single `match` on the lane count.
//
// Input shape:
//
//     specialized!(LANES, Trait {
//         fn name(arg: Ty, ...) -> Ret where [ bounds ] {
//             PATTERN if cfg_predicate => expr,
//             PATTERN => expr,
//         }
//     });
//
// * The `where [ ... ]` part is optional; its tokens are pasted verbatim after `where`.
// * The `if cfg_predicate` part is optional and is NOT a match guard: it is emitted as
//   `#[cfg(cfg_predicate)]` on the arm, so an arm whose predicate is false is removed at compile
//   time instead of being tested at run time.
// * Every generated method is `#[inline(always)]`.
macro_rules! specialized {
($LANES:ident, $trait:ident {
$(
fn $name:ident($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
$width:pat_param $( if $cfg:meta )? => $impl:expr
),+
$(,)?
}
)*
}) => {
impl<const $LANES: usize> $trait<$LANES> for SimdOps {
$(
#[inline(always)]
fn $name($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
// abusing const generics to specialize without the unsoundness of real specialization!
match $LANES {
$(
$( #[cfg( $cfg )] )?
$width => $impl
),+
}
}
)*
}
};
}
// `SimdSplatN for SimdOps`: broadcast the byte `n` to every lane.
//
// For 16-, 32- and 64-lane vectors on x86 / x86_64, when the listed target feature is enabled at
// compile time, the matching `_mm*_set1_epi8` intrinsic is called and its result is reinterpreted
// as `Simd<u8, LANES>` via `cast`. Any other lane count, or a build without the feature, takes the
// `Simd::splat` fallback arm.
//
// NOTE(review): the `unsafe` arms rely on the intrinsic's vector type and `Simd<u8, LANES>` having
// the same size whenever the arm is actually taken (LANES equals the byte width) — confirm.
specialized!(LANES, SimdSplatN {
fn splat_n(n: u8) -> Simd<u8, LANES> where [LaneCount<LANES>: SupportedLaneCount] {
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_set1_epi8(n as i8)) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi8(n as i8)) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi8(n as i8)) },
_ => Simd::splat(n),
}
});
// `SimdSplatZero for SimdOps`: build an all-zero vector.
//
// For 16-, 32- and 64-lane vectors on x86 / x86_64, when the listed target feature is enabled at
// compile time, the matching `_mm*_setzero_*` intrinsic is called and its result is reinterpreted
// as `Simd<u8, LANES>` via `cast`. Any other lane count, or a build without the feature, falls
// back to this type's own `SimdSplatN::splat_n(0)`.
//
// NOTE(review): the `unsafe` arms rely on the intrinsic's vector type and `Simd<u8, LANES>` having
// the same size whenever the arm is actually taken (LANES equals the byte width) — confirm.
specialized!(LANES, SimdSplatZero {
fn splat_zero() -> Simd<u8, LANES> where [LaneCount<LANES>: SupportedLaneCount] {
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
_ => <Self as SimdSplatN<LANES>>::splat_n(0),
}
});
#[inline(always)]
pub fn splat_0<const LANES: usize>() -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount,
{
<SimdOps as SimdSplatZero<LANES>>::splat_zero()
}
#[inline(always)]
pub fn splat_n<const LANES: usize>(n: u8) -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount,
{
<SimdOps as SimdSplatN<LANES>>::splat_n(n)
}

55
src/util.rs Normal file
View File

@ -0,0 +1,55 @@
/// Expands a list of function signatures into forwarding functions.
///
/// Input shape: `=> Target;` followed by zero or more `fn name(param: Ty, ...) -> Ret;` items.
/// `Target` is matched as a single identifier. For each signature the macro emits an
/// `#[inline(always)]` function with the same name, parameters and return type whose body is
/// `<Target>::name(param, ...)`, passing the parameters through in order.
///
/// Every signature must spell out a return type; there is no form without `-> Ret`.
#[doc(hidden)]
#[macro_export]
macro_rules! __defer_impl {
(
=> $impl:ident;
$( fn $name:ident($( $pname:ident: $pty:ty ),*) -> $rty:ty; )*
) => {
$(
#[inline(always)]
fn $name($( $pname: $pty ),*) -> $rty {
<$impl>::$name($( $pname ),*)
}
)*
};
}
pub use __defer_impl as defer_impl;
/// Does nothing; exists only to carry the `#[cold]` attribute.
///
/// `likely` and `unlikely` call it on the branch they consider rare.
#[inline(always)]
#[cold]
pub fn cold() {}
/// Returns `b` unchanged.
///
/// When `b` is `false`, [`cold`] is called first, marking that outcome as the rare one.
#[inline]
pub fn likely(b: bool) -> bool {
    match b {
        true => true,
        false => {
            cold();
            false
        }
    }
}
/// Returns `b` unchanged.
///
/// When `b` is `true`, [`cold`] is called first, marking that outcome as the rare one.
#[inline]
pub fn unlikely(b: bool) -> bool {
    match b {
        true => {
            cold();
            true
        }
        false => false,
    }
}
/// Reinterprets the bytes of `a` as a value of type `B`.
///
/// Like transmute, but implemented via a union so that we can use it in situations where
/// transmute's compile-time size check cannot be satisfied, e.g. in generic code where the two
/// types only have matching sizes on the branches that are actually executed.
///
/// `a` is consumed without running its destructor.
///
/// # Safety
///
/// * `B` must not be larger than `A`; otherwise bytes that were never initialised are read.
///   This is checked with a `debug_assert!` only.
/// * The leading `size_of::<B>()` bytes of `a` must form a valid value of type `B`.
#[inline(always)]
pub unsafe fn cast<A, B>(a: A) -> B {
    use std::mem::ManuallyDrop;

    // `repr(C)` guarantees that both fields start at offset 0; the default union representation
    // makes no layout promise.
    #[repr(C)]
    union Cast<A, B> {
        a: ManuallyDrop<A>,
        b: ManuallyDrop<B>,
    }

    // Deliberately a run-time check: a compile-time one would reject generic callers that
    // instantiate this function on branches they never execute.
    debug_assert!(
        std::mem::size_of::<B>() <= std::mem::size_of::<A>(),
        "cast: destination type is larger than source type"
    );

    let storage = Cast {
        a: ManuallyDrop::new(a),
    };
    // SAFETY: the caller guarantees that the bytes of `a` are a valid `B` (see `# Safety`).
    ManuallyDrop::into_inner(storage.b)
}