This commit is contained in:
Michael Pfaff 2022-10-24 15:14:10 -04:00
parent 964fc7073c
commit 0942be2634
Signed by: michael
GPG Key ID: CF402C4A012AA9D4
2 changed files with 394 additions and 49 deletions

View File

@ -1,6 +1,8 @@
#![feature(generic_const_exprs)]
#![feature(new_uninit)]
#![feature(portable_simd)]
use std::mem::MaybeUninit;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
@ -17,7 +19,7 @@ macro_rules! name {
std::boxed::Box::leak(format!(name!("{}", $f), $group).into_boxed_str())
};
($group:literal, $f:literal) => {
concat!($group, " - ", $f)
concat!("[", $group, "] - ", $f)
};
}
@ -141,12 +143,12 @@ fn benchmark(name: &str, bytes: &[u8], c: &mut Criterion) {
pub fn bench_16(c: &mut Criterion) {
test_sized::<{ ASCII_BYTES.len() }, false>(HEX_BYTES, ASCII_BYTES);
benchmark_sized::<{ ASCII_BYTES.len() }, false>("[16]", HEX_BYTES, c);
benchmark_sized::<{ ASCII_BYTES.len() }, false>("16", HEX_BYTES, c);
}
pub fn bench_256(c: &mut Criterion) {
test_sized::<{ ASCII_BYTES_LONG.len() }, false>(HEX_BYTES_LONG, ASCII_BYTES_LONG);
benchmark_sized::<{ ASCII_BYTES_LONG.len() }, false>("[256]", HEX_BYTES_LONG, c);
benchmark_sized::<{ ASCII_BYTES_LONG.len() }, false>("256", HEX_BYTES_LONG, c);
}
const fn __make_hex_chars() -> [u8; 16] {
@ -227,7 +229,7 @@ impl<'a> std::fmt::Display for DisplayAsHexDigits<'a> {
match hex_digit(*b) {
d @ 0..=9 => f.write_char(('0' as u8 + d) as char),
d @ 10..=15 => f.write_char(('a' as u8 + d - 10) as char),
_ => write!(f, "0x{:x}", b),
_ => write!(f, "0x{:02x}", b),
}?;
f.write_char(' ')?;
}
@ -256,19 +258,65 @@ pub fn bench_1_6m(c: &mut Criterion) {
}
};
test_sized::<LEN, true>(&hex_bytes, bytes.as_ref().try_into().unwrap());
benchmark_sized::<LEN, true>("[1.6m]", &hex_bytes, c);
benchmark_sized::<LEN, true>("1.6m", &hex_bytes, c);
}
pub fn bench_hex_digit(c: &mut Criterion) {
let digit = ['5' as u8, 'b' as u8];
/// Benchmarks decoding one full batch (`WIDE_BATCH_SIZE` hex byte pairs) with the scalar
/// `hex_byte`, the scalar `hex_byte_niched` and the vectorised `hex_byte_simd` decoders.
///
/// Inputs and outputs are routed through `black_box` on every iteration so that the optimiser
/// can neither constant-fold the decoding nor drop it as dead code; all three variants are
/// treated the same way so their timings stay comparable.
pub fn bench_micro_hex_byte(c: &mut Criterion) {
    use std::simd::Simd;

    const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [
        *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
        *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
        *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
        *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
    ];
    let hex_bytes = Simd::from_array(conv::u8x2_to_u8(black_box(HEX_BYTES_VALID)));

    c.bench_function(name!("micro", "hex_byte"), |b| {
        b.iter(|| {
            for pair in black_box(HEX_BYTES_VALID) {
                // Without this the result is unused and the call may be optimised away.
                black_box(hex_byte(pair[0], pair[1]));
            }
        })
    });
    c.bench_function(name!("micro", "hex_byte_niched"), |b| {
        b.iter(|| {
            for pair in black_box(HEX_BYTES_VALID) {
                // Without this the result is unused and the call may be optimised away.
                black_box(hex_byte_niched(pair[0], pair[1]));
            }
        })
    });
    c.bench_function(name!("micro", "hex_byte_simd"), |b| {
        // Re-obscure the input on each iteration, matching the scalar variants above, so the
        // loop-invariant decode cannot be hoisted out of the measurement loop.
        b.iter(|| hex_byte_simd(black_box(hex_bytes)))
    });
}
pub fn bench_nano_hex_byte(c: &mut Criterion) {
let digit = black_box(['5' as u8, 'b' as u8]);
c.bench_function(name!("nano", "hex_byte"), |b| {
b.iter(|| hex_byte(digit[0], digit[1]))
});
c.bench_function(name!("nano", "hex_byte_niched"), |b| {
b.iter(|| hex_byte_niched(digit[0], digit[1]))
});
c.bench_function(name!("nano", "hex_byte +bb"), |b| {
b.iter(|| hex_byte(black_box(digit[0]), black_box(digit[1])))
});
c.bench_function(name!("micro", "hex_byte_niched"), |b| {
c.bench_function(name!("nano", "hex_byte_niched +bb"), |b| {
b.iter(|| hex_byte_niched(black_box(digit[0]), black_box(digit[1])))
});
}
criterion_group!(decode_benches, bench_16, bench_256, bench_1_6m);
criterion_group!(micro_benches, bench_hex_digit);
criterion_main!(decode_benches, micro_benches);
criterion_group!(micro_benches, bench_micro_hex_byte);
criterion_group!(nano_benches, bench_nano_hex_byte);
criterion_main!(decode_benches, micro_benches, nano_benches);

View File

@ -10,11 +10,72 @@
#![feature(const_maybe_uninit_uninit_array)]
#![feature(new_uninit)]
use std::fmt;
use std::mem::MaybeUninit;
#![feature(portable_simd)]
#[inline]
const fn __make_ascii_digit_table() -> [u8; 256] {
use std::mem::MaybeUninit;
use std::simd::*;
/// The batch size used for the "wide" decoded hex bytes (any bit in the upper half indicates an error).
///
/// Each decoded byte occupies one 16-bit lane, so `512 / 16` is the number of such lanes that
/// fit in a 512-bit vector (the maximum batch size that would be supported by AVX-512).
pub const WIDE_BATCH_SIZE: usize = 512 / 16;
/// The batch size used for the hex digits.
///
/// Two ASCII digits are consumed per decoded byte, hence twice [`WIDE_BATCH_SIZE`].
// use the maximum batch size that would be supported by AVX-512
pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2;
/// Builds a boolean mask of length `N` whose elements alternate between `true` and `false`.
///
/// When `first_bias` is `true` the even positions (0, 2, 4, ...) are set; otherwise the odd
/// positions (1, 3, 5, ...) are set. Any `N` is accepted, including zero and odd lengths.
const fn alternating_mask<const N: usize>(first_bias: bool) -> [bool; N] {
    let mut mask = [false; N];
    // Walk the selected positions directly instead of counting `N / 2` pairs: counting pairs
    // would never reach the final (even-indexed) element of an odd-length mask.
    let mut i = if first_bias { 0 } else { 1 };
    while i < N {
        mask[i] = true;
        i += 2;
    }
    mask
}
/// Produces the permutation of `0..N` that de-interleaves pairs: the first half of the result
/// holds the even indices (0, 2, 4, ...) and the second half holds the odd indices
/// (1, 3, 5, ...).
///
/// # Panics
///
/// Panics with "Illegal N" when `N` is odd.
const fn msb_lsb_indices<const N: usize>() -> [usize; N] {
    assert!(N % 2 == 0, "Illegal N");
    let half = N / 2;
    let mut out = [0; N];
    let mut pos = 0;
    while pos < N {
        out[pos] = if pos < half {
            pos * 2
        } else {
            (pos - half) * 2 + 1
        };
        pos += 1;
    }
    out
}
/// Produces `N` indices that select every other element of a sequence twice as long.
///
/// With `first_bias` set the result is `0, 2, 4, ...`; otherwise it is `1, 3, 5, ...`.
const fn alternating_indices<const N: usize>(first_bias: bool) -> [usize; N] {
    // Even slots start at 0, odd slots are shifted by one.
    let offset = if first_bias { 0 } else { 1 };
    let mut out = [0; N];
    let mut slot = 0;
    while slot < N {
        out[slot] = slot * 2 + offset;
        slot += 1;
    }
    out
}
/// Lane indices `0, 2, 4, ...` of a digit batch, i.e. the first digit of each two-digit pair
/// (used as the swizzle pattern that extracts the most-significant digits).
const MSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(true);
/// Lane indices `1, 3, 5, ...` of a digit batch, i.e. the second digit of each two-digit pair
/// (used as the swizzle pattern that extracts the least-significant digits).
const LSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(false);
const ASCII_DIGITS: [u8; 256] = {
let mut digits = [0u8; 256];
let mut i = u8::MIN;
while i < u8::MAX {
@ -35,9 +96,7 @@ const fn __make_ascii_digit_table() -> [u8; 256] {
i += 1;
}
digits
}
const ASCII_DIGITS: [u8; 256] = __make_ascii_digit_table();
};
/// Returns 255 if invalid. Based on `char.to_digit()` in the stdlib.
#[inline]
@ -68,6 +127,11 @@ pub const fn hex_digit(ascii: u8) -> u8 {
// return 255;
}
/// Vectorised digit lookup: every lane of `ascii` is used as an index into the 256-entry
/// `ASCII_DIGITS` table and replaced by the table entry found there.
#[inline(always)]
pub fn hex_digits<const LANES: usize>(ascii: Simd<u8, LANES>) -> Simd<u8, LANES>
where
    LaneCount<LANES>: SupportedLaneCount,
{
    let indices = ascii.cast::<usize>();
    let every_lane = Mask::splat(true);
    let fallback = Simd::splat(0);
    // SAFETY: each index originates from a `u8`, so it is at most 255 and therefore always in
    // bounds for the 256-element table.
    unsafe { Simd::gather_select_unchecked(&ASCII_DIGITS, every_lane, indices, fallback) }
}
/// Parses an ascii hex byte.
#[inline]
pub const fn hex_byte(msb: u8, lsb: u8) -> Option<u8> {
@ -106,6 +170,114 @@ pub const fn hex_byte_packed_niched([msb, lsb]: &[u8; 2]) -> u16 {
(msb << 4) | (lsb & 0xf) | ((lsb & 0xf0) << 8)
}
#[inline]
pub fn hex_byte_simd(hex_bytes: Simd<u8, DIGIT_BATCH_SIZE>) -> Simd<u16, WIDE_BATCH_SIZE> {
let hex_digits = hex_digits(hex_bytes);
//println!("hex_digits: {hex_digits:04x?}");
//println!("MSB_INDICES: {MSB_INDICES:02?}");
//println!("LSB_INDICES: {LSB_INDICES:02?}");
let msb = simd_swizzle!(hex_digits, MSB_INDICES);
let lsb = simd_swizzle!(hex_digits, LSB_INDICES);
//println!("msb: {msb:04x?}");
//println!("lsb: {lsb:04x?}");
let msb = msb.cast::<u16>();
let lsb = lsb.cast::<u16>();
msb << Simd::splat(4) | lsb | ((lsb & Simd::splat(0xf0)) << Simd::splat(8))
}
/// Bit-level reinterpretation helpers between byte arrays / vectors and their 16-bit or
/// pair-grouped counterparts. Each conversion is implemented by writing one field of a private
/// union and reading the other, so no bytes are reordered or copied element-wise.
pub mod conv {
use std::simd::{Simd, LaneCount, SupportedLaneCount};
// NOTE: the block comment below is disabled code (type-level size bookkeeping); it is not compiled.
/*trait Size {
const N: usize;
}
macro_rules! size_impl {
($ident:ident($size:expr)) => {
struct $ident;
impl Size for $ident {
const N: usize = $size;
}
};
($ident:ident<$size:ty>) => {
size_impl!($ident(std::mem::size_of::<$size>()));
};
}
struct SizeMul<const N: usize, T>(std::marker::PhantomData<T>);
impl<const N: usize, T: Size> Size for SizeMul<N, T> {
const N: usize = T::N * N;
}
size_impl!(SizeU8<u8>);
size_impl!(SizeU16<u16>);
size_impl!(SizeU32<u32>);
size_impl!(SizeU64<u64>);
trait SizeOf {
type Size: Size;
//const SIZE: usize;
}
//impl<T> SizeOf for T {
// const SIZE: usize = std::mem::size_of::<T>();
//}
macro_rules! size_of_impl {
($type:ty = $size:ident) => {
impl SizeOf for $type {
type Size = $size;
}
};
}
size_of_impl!(u8 = SizeU8);
size_of_impl!([u8; 2] = SizeU16);
size_of_impl!(u16 = SizeU16);
size_of_impl!(u32 = SizeU32);
size_of_impl!(u64 = SizeU64);*/
/// Overlays `N_u16 * 2` bytes with `N_u16` 16-bit integers (same total size).
#[allow(non_camel_case_types, non_snake_case)]
union u8_u16<const N_u16: usize> where [u8; N_u16 * 2]: {
u8: [u8; N_u16 * 2],
u16: [u16; N_u16],
}
/// Overlays `N_u8x2` two-byte pairs with the same bytes as one flat array.
#[allow(non_camel_case_types, non_snake_case)]
union u8x2_u8<const N_u8x2: usize> where [u8; N_u8x2 * 2]: {
u8x2: [[u8; 2]; N_u8x2],
u8: [u8; N_u8x2 * 2],
}
/// Overlays a `u8` vector of `N_U16 * 2` lanes with a `u16` vector of `N_U16` lanes.
#[allow(non_camel_case_types, non_snake_case)]
union SimdU8_SimdU16<const N_U16: usize> where LaneCount<{ N_U16 * 2 }>: SupportedLaneCount, LaneCount<N_U16>: SupportedLaneCount {
SimdU8: Simd<u8, { N_U16 * 2 }>,
SimdU16: Simd<u16, N_U16>,
}
/// Reinterprets `N_OUT * 2` bytes as `N_OUT` `u16` values.
///
/// The bytes are not swapped, so each resulting value follows the target's native byte order.
#[inline(always)]
pub const fn u8_to_u16<const N_OUT: usize>(a: [u8; N_OUT * 2]) -> [u16; N_OUT] {
// SAFETY: both fields are plain integer arrays of identical total size, so every bit
// pattern written through one field is a valid value of the other.
unsafe { u8_u16 { u8: a }.u16 }
}
/// Flattens an array of two-byte pairs into a single byte array, preserving byte order.
#[inline(always)]
pub const fn u8x2_to_u8<const N_IN: usize>(a: [[u8; 2]; N_IN]) -> [u8; N_IN * 2] {
// SAFETY: both fields are byte arrays of identical total size, so every bit pattern
// written through one field is a valid value of the other.
unsafe { u8x2_u8 { u8x2: a }.u8 }
}
/// Reinterprets a `u8` vector with `N_OUT * 2` lanes as a `u16` vector with `N_OUT` lanes.
///
/// The bytes are not swapped, so each resulting lane follows the target's native byte order.
#[inline(always)]
pub const fn simdu8_to_simdu16<const N_OUT: usize>(a: Simd<u8, { N_OUT * 2 }>) -> Simd<u16, N_OUT>
where
LaneCount<{ N_OUT * 2 }>: SupportedLaneCount,
LaneCount<N_OUT>: SupportedLaneCount
{
// SAFETY: both fields are integer vectors of identical total size, so every bit pattern
// written through one field is a valid value of the other.
unsafe { SimdU8_SimdU16 { SimdU8: a }.SimdU16 }
}
}
#[inline(always)]
const fn align_down_to<const N: usize>(n: usize) -> usize {
let shift = match N.checked_ilog2() {
@ -124,50 +296,88 @@ const fn align_up_to<const N: usize>(n: usize) -> usize {
return (n + (N - 1)) >> shift << shift;
}
/// Scalar decode loop: starting at index `$i`, consumes `$ascii` two digits at a time and
/// stores each decoded byte at `$bytes[$o + $i / 2]`.
///
/// `$i` is advanced in place. On the first pair that `hex_byte` rejects, a diagnostic is
/// printed and the *enclosing function* returns `false`.
///
/// The expansion uses unchecked indexing: the caller must guarantee that `$ascii` contains a
/// complete pair at every visited index and that `$bytes` is long enough for every write.
macro_rules! decode_hex_bytes_non_vectored {
    ($i:ident, $ascii:ident, $bytes:ident, $o:expr) => {{
        while $i < $ascii.len() {
            // SAFETY: `$i` is below `$ascii.len()`; the caller guarantees `$i + 1` is as well.
            let pair = unsafe { (*$ascii.get_unchecked($i), *$ascii.get_unchecked($i + 1)) };
            if let Some(b) = hex_byte(pair.0, pair.1) {
                // SAFETY: the caller guarantees the destination index is within `$bytes`.
                unsafe { *$bytes.get_unchecked_mut($o + ($i >> 1)) = MaybeUninit::new(b) };
            } else {
                println!("bad hex byte at {} ({}{})", $i, $ascii[$i] as char, $ascii[$i + 1] as char);
                return false;
            }
            $i += 2;
        }
    }};
}
#[inline(always)]
fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bool {
debug_assert_eq!(ascii.len() >> 1 << 1, ascii.len(), "len of ascii is not a multiple of 2");
debug_assert_eq!(ascii.len() >> 1, bytes.len(), "len of ascii is not twice that of bytes");
let mut i = 0;
// use the maximum batch size that would be supported by AVX-512
const BATCH_SIZE: usize = 512 / 16;
const VECTORED: bool = true;
if VECTORED {
while i < align_down_to::<BATCH_SIZE>(bytes.len()) {
let mut buf = MaybeUninit::<u16>::uninit_array::<BATCH_SIZE>();
union Aligned {
bytes: [u8; DIGIT_BATCH_SIZE],
simd: Simd<u8, DIGIT_BATCH_SIZE>,
}
let mut i = 0;
while i < align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) {
let slice = unsafe { &*(ascii.as_ptr().add(i) as *const [u8; DIGIT_BATCH_SIZE]) };
let aligned = Aligned { bytes: *slice };
let buf = hex_byte_simd(unsafe { aligned.simd });
if buf > Simd::splat(u8::MAX as u16) {
println!("bad hex byte at {}..{}", i, i + DIGIT_BATCH_SIZE);
return false;
}
let buf = buf.cast::<u8>();
let mut j = 0;
while j < buf.len() {
while j < DIGIT_BATCH_SIZE {
unsafe {
*buf.get_unchecked_mut(j) = MaybeUninit::new(hex_byte_packed_niched(
&*(ascii.as_ptr().add((i + j) << 1) as *const [u8; 2]),
))
*bytes.get_unchecked_mut((i >> 1) + j) = MaybeUninit::new(*buf.as_array().get_unchecked(j))
};
j += 1;
}
let buf = unsafe { MaybeUninit::array_assume_init(buf) };
for x in buf.iter() {
if *x > u8::MAX as u16 {
return false;
}
i += DIGIT_BATCH_SIZE;
}
decode_hex_bytes_non_vectored!(i, ascii, bytes, 0);
} else if false {
let (ascii_pre, ascii_simd, ascii_post) = unsafe { ascii.align_to::<Simd<u8, DIGIT_BATCH_SIZE>>() };
assert_eq!(ascii_pre.len() % 2, 0);
assert_eq!(ascii_post.len() % 2, 0);
let mut i = 0;
decode_hex_bytes_non_vectored!(i, ascii_pre, bytes, 0);
let mut i = 0;
while i < ascii_simd.len() {
let buf = hex_byte_simd(unsafe { *ascii_simd.get_unchecked(i) });
if buf > Simd::splat(u8::MAX as u16) {
println!("bad hex byte at {}..{}", i, i + DIGIT_BATCH_SIZE);
return false;
}
let buf = buf.cast::<u8>();
let mut j = 0;
while j < buf.len() {
let k = ascii_pre.len() + i * DIGIT_BATCH_SIZE;
while j < DIGIT_BATCH_SIZE {
unsafe {
*bytes.get_unchecked_mut(i + j) = MaybeUninit::new(*buf.get_unchecked(j) as u8)
*bytes.get_unchecked_mut(k + j) = MaybeUninit::new(*buf.as_array().get_unchecked(j))
};
j += 1;
}
i += buf.len();
i += 1;
}
}
while i < bytes.len() {
match hex_byte(unsafe { *ascii.get_unchecked(i << 1) }, unsafe {
*ascii.get_unchecked((i << 1) + 1)
}) {
Some(b) => unsafe { *bytes.get_unchecked_mut(i) = MaybeUninit::new(b) },
None => return false,
}
i += 1;
let mut i = 0;
let k = ascii.len() - ascii_post.len();
decode_hex_bytes_non_vectored!(i, ascii_post, bytes, k);
} else {
let mut i = 0;
decode_hex_bytes_non_vectored!(i, ascii, bytes, 0);
}
true
}
@ -303,6 +513,19 @@ mod test {
const BYTES: &[u8] = b"Donald J. Trump!";
const HEX_BYTES: &[u8] = b"446F6E616C64204A2E205472756D7021";
const LONG_BYTES: &[u8] = b"Dolorum distinctio ut earum quidem distinctio necessitatibus quam. Sit praesentium facere perspiciatis iure aut sunt et et. Adipisci enim rerum illum et officia nisi recusandae. Vitae doloribus ut quia ea unde consequuntur quae illum. Id eius harum est. Inventore ipsum ut sit ut vero consectetur.";
const LONG_HEX_BYTES: &[u8] = b"446F6C6F72756D2064697374696E6374696F20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E";
/// A decoding test vector: raw bytes paired with their ASCII hex encoding.
struct Sample {
// The expected decoded output.
bytes: &'static [u8],
// The ASCII hex text that should decode to `bytes`.
hex_bytes: &'static [u8],
}
/// Test vectors shared by the decoder tests: one short and one long sample.
const SAMPLES: &[Sample] = &[
Sample { bytes: BYTES, hex_bytes: HEX_BYTES },
Sample { bytes: LONG_BYTES, hex_bytes: LONG_HEX_BYTES },
];
#[test]
fn test_hex_digit() {
const HEX_DIGITS_LOWER: &[char; 16] = &[
@ -321,6 +544,24 @@ mod test {
}
}
/// Checks that `hex_digits` maps a full batch of lower- and upper-case hex digits to the
/// values 0 through 15, repeating.
#[test]
fn test_hex_digit_simd() {
    const HEX_DIGITS: &[char; DIGIT_BATCH_SIZE] = &[
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F',
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F',
    ];

    // ASCII codes of the digits above.
    let ascii: [u8; DIGIT_BATCH_SIZE] = std::array::from_fn(|i| HEX_DIGITS[i] as u8);
    // Every run of sixteen digits counts 0..=15.
    let expected: [u8; DIGIT_BATCH_SIZE] = std::array::from_fn(|i| (i as u8) % 16);

    assert_eq!(hex_digits(Simd::from_array(ascii)), Simd::from_array(expected));
}
#[test]
fn test_hex_byte() {
const HEX_BYTES_VALID: &[([u8; 2], u8)] = &[
@ -355,14 +596,70 @@ mod test {
}
#[test]
fn test_non_niched() {
let result = hex_bytes_dyn(HEX_BYTES);
assert_eq!(Some(BYTES), result.as_ref().map(Box::as_ref));
fn test_hex_byte_simd() {
const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
];
const BYTES_VALID: [u8; WIDE_BATCH_SIZE] = [
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
];
let hex_bytes = Simd::from_array(conv::u8x2_to_u8(HEX_BYTES_VALID));
let bytes = BYTES_VALID.map(|b| b as u16);
let bytes = Simd::from_array(bytes);
println!("hex_bytes: {HEX_BYTES_VALID:02x?}");
println!("hex_bytes: {hex_bytes:02x?}");
println!("bytes: {BYTES_VALID:02x?}");
println!("bytes: {bytes:04x?}");
assert_eq!(hex_byte_simd(hex_bytes), bytes);
/*const HEX_BYTES_INVALID: &[[u8; 2]] = &[
['f' as u8, 'g' as u8],
['0' as u8, 'g' as u8],
['1' as u8, 'g' as u8],
['e' as u8, 'g' as u8],
['f' as u8, 'g' as u8],
['0' as u8, 'g' as u8],
['f' as u8, 'g' as u8],
];
for hb in HEX_BYTES_INVALID {
assert_eq!(hex_byte(hb[0], hb[1]), None);
assert!(hex_byte_niched(hb[0], hb[1]) & 0xff_00 != 0);
}*/
}
fn test_f(f: fn(&[u8]) -> Option<Box<[u8]>>) {
for (i, Sample { bytes, hex_bytes }) in SAMPLES.into_iter().enumerate() {
let result = f(hex_bytes);
assert_eq!(Some(*bytes), result.as_ref().map(Box::as_ref), "Sample {i} did not decode correctly");
}
}
#[test]
fn test_niched() {
let result = hex_bytes_dyn_niched(HEX_BYTES);
assert_eq!(Some(BYTES), result.as_ref().map(Box::as_ref));
fn test_dyn_iter_option() {
test_f(hex_bytes_dyn);
}
/// Runs the shared sample vectors through `hex_bytes_dyn_niched`.
#[test]
fn test_dyn_iter_u16() {
test_f(hex_bytes_dyn_niched);
}
/// Runs the shared sample vectors through `hex_bytes_dyn_unsafe`.
#[test]
fn test_dyn_unsafe() {
test_f(hex_bytes_dyn_unsafe);
}
}