This commit is contained in:
Michael Pfaff 2022-10-30 15:35:17 -04:00
parent 0942be2634
commit 379fa71d78
Signed by: michael
GPG Key ID: CF402C4A012AA9D4
5 changed files with 574 additions and 163 deletions

View File

@ -6,7 +6,7 @@ edition = "2021"
[dependencies]
[dev-dependencies]
criterion = "0.3"
criterion = { version = "0.4", features = [ "real_blackbox" ] }
rand = "0.8.5"
[[bench]]

View File

@ -1,6 +1,5 @@
#![feature(generic_const_exprs)]
#![feature(new_uninit)]
#![feature(portable_simd)]
use std::mem::MaybeUninit;
@ -15,15 +14,21 @@ const ASCII_BYTES_LONG: &[u8; 256] = b"Donald J. Trump!Donald J. Trump!Donald J.
const HEX_BYTES_LONG: &[u8; ASCII_BYTES_LONG.len() * 2] = b"446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021";
macro_rules! name {
($group:ident, $f:literal) => {
std::boxed::Box::leak(format!(name!("{}", $f), $group).into_boxed_str())
};
($group:literal, $f:literal) => {
concat!("[", $group, "] - ", $f)
};
($group:expr, $f:literal) => {
std::boxed::Box::leak(format!(name!("{}", $f), $group).into_boxed_str())
};
($group:expr, $f:expr) => {
std::boxed::Box::leak(format!(name!("{}", "{}"), $group, $f).into_boxed_str())
};
($group:literal, $f:expr) => {
std::boxed::Box::leak(format!(name!($group, "{}"), $f).into_boxed_str())
};
}
//#[track_caller]
#[track_caller]
fn test_sized<const N: usize, const HEAP_ONLY: bool>(hex_bytes: &[u8; N * 2], bytes: &[u8; N])
where
[(); N * 2]:,
@ -50,7 +55,7 @@ where
);
}
//#[track_caller]
#[track_caller]
#[inline]
fn test(hex_bytes: &[u8], bytes: &[u8]) {
assert_eq!(hex_bytes.len(), bytes.len() * 2);
@ -261,62 +266,120 @@ pub fn bench_1_6m(c: &mut Criterion) {
benchmark_sized::<LEN, true>("1.6m", &hex_bytes, c);
}
pub fn bench_micro_hex_byte(c: &mut Criterion) {
pub fn bench_micro_hex_digit(c: &mut Criterion) {
use std::simd::Simd;
const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
const HEX_DIGITS_VALID: [u8; DIGIT_BATCH_SIZE] = [
0xf, 0xf, 0x0, 0x0, 0x1, 0x1, 0xe, 0xf, 0xf, 0xe, 0x0, 0xf, 0xf, 0x0, 0x3, 0x4, 0xf, 0xf,
0x0, 0x0, 0x1, 0x1, 0xe, 0xf, 0xf, 0xe, 0x0, 0xf, 0xf, 0x0, 0x3,
0x4,
// 0xf, 0xf, 0x0, 0x0, 0x1, 0x1, 0xe, 0xf, 0xf, 0xe, 0x0, 0xf, 0xf, 0x0, 0x3, 0x4,
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
// 0xf, 0xf, 0x0, 0x0, 0x1, 0x1, 0xe, 0xf, 0xf, 0xe, 0x0, 0xf, 0xf, 0x0, 0x3, 0x4,
];
let hex_bytes = Simd::from_array(conv::u8x2_to_u8(black_box(HEX_BYTES_VALID)));
let hex_digits = Simd::from_array(black_box(HEX_DIGITS_VALID));
c.bench_function(name!("micro", "hex_digit"), |b| {
b.iter(|| {
for b in black_box(HEX_DIGITS_VALID) {
black_box(hex_digit(b));
}
})
});
c.bench_function(name!("micro", "hex_digit_simd"), |b| {
b.iter(|| hex_digit_simd::<DIGIT_BATCH_SIZE>(hex_digits))
});
}
pub fn bench_micro_hex_byte(c: &mut Criterion) {
const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", *b"ff", *b"00", *b"11",
*b"ef", *b"fe", *b"0f", *b"f0",
*b"34",
// *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
// *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
];
fn bench_decoder<T: HexByteDecoder + HexByteSimdDecoder>(c: &mut Criterion, name: &str) {
let hex_bytes = conv::u8x2_to_u8(HEX_BYTES_VALID);
c.bench_function(name!("micro", format!("{name}::decode_packed")), |b| {
b.iter(|| {
for b in black_box(HEX_BYTES_VALID) {
black_box(T::decode_packed(&b));
}
})
});
c.bench_function(name!("micro", format!("{name}::decode_unpacked")), |b| {
b.iter(|| {
for [hi, lo] in black_box(HEX_BYTES_VALID) {
black_box(T::decode_unpacked(hi, lo));
}
})
});
c.bench_function(name!("micro", format!("{name}::decode_simd")), |b| {
b.iter(|| T::decode_simd(black_box(hex_bytes)))
});
}
c.bench_function(name!("micro", "hex_byte"), |b| {
b.iter(|| {
for b in black_box(HEX_BYTES_VALID) {
hex_byte(b[0], b[1]);
black_box(hex_byte(b[0], b[1]));
}
})
});
c.bench_function(name!("micro", "hex_byte_niched"), |b| {
b.iter(|| {
for b in black_box(HEX_BYTES_VALID) {
hex_byte_niched(b[0], b[1]);
}
})
});
bench_decoder::<HexByteDecoderA>(c, stringify!(HexByteDecoderA));
bench_decoder::<HexByteDecoderB>(c, stringify!(HexByteDecoderB));
}
c.bench_function(name!("micro", "hex_byte_simd"), |b| {
b.iter(|| hex_byte_simd(hex_bytes))
pub fn bench_nano_hex_digit(c: &mut Criterion) {
let digit = black_box('5' as u8);
c.bench_function(name!("nano", "hex_digit"), |b| b.iter(|| hex_digit(digit)));
c.bench_function(name!("nano", "hex_digit +bb"), |b| {
b.iter(|| hex_digit(black_box(digit)))
});
}
pub fn bench_nano_hex_byte(c: &mut Criterion) {
let digit = black_box(['5' as u8, 'b' as u8]);
const DIGITS: [u8; 2] = ['5' as u8, 'b' as u8];
let digit = black_box(DIGITS);
c.bench_function(name!("nano", "hex_byte"), |b| {
b.iter(|| hex_byte(digit[0], digit[1]))
});
c.bench_function(name!("nano", "hex_byte_niched"), |b| {
b.iter(|| hex_byte_niched(digit[0], digit[1]))
});
fn bench_decoder<T: HexByteDecoder + HexByteSimdDecoder>(c: &mut Criterion, name: &str) {
let digit = black_box(DIGITS);
c.bench_function(name!("nano", "hex_byte +bb"), |b| {
b.iter(|| hex_byte(black_box(digit[0]), black_box(digit[1])))
});
c.bench_function(name!("nano", format!("{name}::decode_packed")), |b| {
b.iter(|| {
black_box(T::decode_packed(&digit));
})
});
c.bench_function(name!("nano", "hex_byte_niched +bb"), |b| {
b.iter(|| hex_byte_niched(black_box(digit[0]), black_box(digit[1])))
});
c.bench_function(name!("nano", format!("{name}::decode_unpacked")), |b| {
b.iter(|| {
black_box(T::decode_unpacked(
black_box(DIGITS[0]),
black_box(DIGITS[1]),
));
})
});
}
bench_decoder::<HexByteDecoderA>(c, stringify!(HexByteDecoderA));
bench_decoder::<HexByteDecoderB>(c, stringify!(HexByteDecoderB));
}
criterion_group!(decode_benches, bench_16, bench_256, bench_1_6m);
criterion_group!(micro_benches, bench_micro_hex_byte);
criterion_group!(nano_benches, bench_nano_hex_byte);
criterion_group!(micro_benches, bench_micro_hex_digit, bench_micro_hex_byte);
criterion_group!(nano_benches, bench_nano_hex_digit, bench_nano_hex_byte);
criterion_main!(decode_benches, micro_benches, nano_benches);

View File

@ -1,5 +1,6 @@
#![feature(array_chunks)]
#![feature(const_slice_index)]
#![feature(const_trait_impl)]
#![feature(extend_one)]
#![feature(generic_const_exprs)]
#![feature(int_log)]
@ -9,19 +10,26 @@
#![feature(const_maybe_uninit_array_assume_init)]
#![feature(const_maybe_uninit_uninit_array)]
#![feature(new_uninit)]
#![feature(portable_simd)]
pub(crate) mod util;
pub(crate) mod simd;
use std::mem::MaybeUninit;
use std::simd::*;
// use the maximum batch size that would be supported by AVX-512
//pub const SIMD_WIDTH: usize = 512;
pub const SIMD_WIDTH: usize = 256;
/// The batch size used for the "wide" decoded hex bytes (any bit in the upper half indicates an error).
pub const WIDE_BATCH_SIZE: usize = 512 / 16;
pub const WIDE_BATCH_SIZE: usize = SIMD_WIDTH / 16;
/// The batch size used for the hex digits.
// use the maximum batch size that would be supported by AVX-512
pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2;
#[inline]
const fn alternating_mask<const N: usize>(first_bias: bool) -> [bool; N] {
let mut mask = [false; N];
let mut i = 0;
@ -39,6 +47,7 @@ const fn alternating_mask<const N: usize>(first_bias: bool) -> [bool; N] {
mask
}
#[inline]
const fn msb_lsb_indices<const N: usize>() -> [usize; N] {
if N % 2 != 0 {
panic!("Illegal N");
@ -55,6 +64,7 @@ const fn msb_lsb_indices<const N: usize>() -> [usize; N] {
indices
}
#[inline]
const fn alternating_indices<const N: usize>(first_bias: bool) -> [usize; N] {
let mut indices = [0; N];
let mut i = 0;
@ -75,6 +85,10 @@ const fn alternating_indices<const N: usize>(first_bias: bool) -> [usize; N] {
const MSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(true);
const LSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(false);
pub const INVALID_BIT: u8 = 0b1000_0000;
pub const WIDE_INVALID_BIT: u16 = 0b1000_1000_0000_0000;
const ASCII_DIGITS: [u8; 256] = {
let mut digits = [0u8; 256];
let mut i = u8::MIN;
@ -90,7 +104,7 @@ const ASCII_DIGITS: [u8; 256] = {
DIGIT_MIN..=DIGIT_MAX => i - DIGIT_MIN,
LOWER_MIN..=LOWER_MAX => 10 + i - LOWER_MIN,
UPPER_MIN..=UPPER_MAX => 10 + i - UPPER_MIN,
_ => 255,
_ => INVALID_BIT,
};
i += 1;
@ -98,7 +112,7 @@ const ASCII_DIGITS: [u8; 256] = {
digits
};
/// Returns 255 if invalid. Based on `char.to_digit()` in the stdlib.
/// Returns [`INVALID_BIT`] if invalid. Based on `char.to_digit()` in the stdlib.
#[inline]
pub const fn hex_digit(ascii: u8) -> u8 {
// use std::ops::RangeInclusive;
@ -112,7 +126,7 @@ pub const fn hex_digit(ascii: u8) -> u8 {
// DIGIT_MIN..=DIGIT_MAX => ascii - DIGIT_MIN,
// LOWER_MIN..=LOWER_MAX => 10 + ascii - LOWER_MIN,
// UPPER_MIN..=UPPER_MAX => 10 + ascii - UPPER_MIN,
// _ => 255,
// _ => INVALID_BIT,
// }
ASCII_DIGITS[ascii as usize]
// let mut digit = ascii.wrapping_sub('0' as u8);
@ -124,69 +138,162 @@ pub const fn hex_digit(ascii: u8) -> u8 {
// if digit < 6 {
// return digit + 10;
// }
// return 255;
// return INVALID_BIT;
}
#[inline(always)]
pub fn hex_digits<const LANES: usize>(ascii: Simd<u8, LANES>) -> Simd<u8, LANES> where LaneCount<LANES>: SupportedLaneCount {
unsafe { Simd::gather_select_unchecked(&ASCII_DIGITS, Mask::splat(true), ascii.cast(), Simd::splat(0)) }
pub fn hex_digit_simd<const LANES: usize>(ascii: Simd<u8, LANES>) -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount,
{
unsafe {
Simd::gather_select_unchecked(
&ASCII_DIGITS,
Mask::splat(true),
ascii.cast(),
simd::splat_0::<LANES>(),
)
}
}
/// Parses an ascii hex byte.
#[inline]
#[inline(always)]
pub const fn hex_byte(msb: u8, lsb: u8) -> Option<u8> {
let msb = hex_digit(msb);
let lsb = hex_digit(lsb);
if (msb | lsb) == 255 {
// second is faster (perhaps it pipelines better?)
//if (msb | lsb) & INVALID_BIT != 0 {
if (msb & INVALID_BIT) | (lsb & INVALID_BIT) != 0 {
return None;
}
Some(msb << 4 | lsb)
}
/// Parses an ascii hex byte. Any value > [`u8::MAX`] is invalid.
#[inline]
pub const fn hex_byte_niched(msb: u8, lsb: u8) -> u16 {
let msb = hex_digit(msb) as u16;
let lsb = hex_digit(lsb) as u16;
(msb << 4) | (lsb & 0xf) | ((lsb & 0xf0) << 8)
}
/// A decoder for a single hex byte.
#[const_trait]
pub trait HexByteDecoder {
/// Parses an ascii hex byte. Any return value exceeding [`u8::MAX`] indicates invalid input.
fn decode_unpacked(hi: u8, lo: u8) -> u16;
/// Parses an ascii hex byte.
#[inline]
pub const fn hex_byte_packed([msb, lsb]: &[u8; 2]) -> Option<u8> {
let msb = hex_digit(*msb);
let lsb = hex_digit(*lsb);
if (msb | lsb) == 255 {
return None;
/// Parses an ascii hex byte. Any return value exceeding [`u8::MAX`] indicates invalid input.
#[inline(always)]
fn decode_packed([hi, lo]: &[u8; 2]) -> u16 {
Self::decode_unpacked(*hi, *lo)
}
Some(msb << 4 | lsb)
}
/// Parses an ascii hex byte. Any value > [`u8::MAX`] is invalid.
#[inline]
pub const fn hex_byte_packed_niched([msb, lsb]: &[u8; 2]) -> u16 {
let msb = hex_digit(*msb) as u16;
let lsb = hex_digit(*lsb) as u16;
(msb << 4) | (lsb & 0xf) | ((lsb & 0xf0) << 8)
/// A decoder for a sized batch of hex bytes.
pub trait HexByteSimdDecoder {
/// Parses an ascii hex byte. Any element of the return value exceeding [`u8::MAX`] indicates invalid input.
fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>>;
}
#[inline]
pub fn hex_byte_simd(hex_bytes: Simd<u8, DIGIT_BATCH_SIZE>) -> Simd<u16, WIDE_BATCH_SIZE> {
let hex_digits = hex_digits(hex_bytes);
//println!("hex_digits: {hex_digits:04x?}");
//println!("MSB_INDICES: {MSB_INDICES:02?}");
//println!("LSB_INDICES: {LSB_INDICES:02?}");
let msb = simd_swizzle!(hex_digits, MSB_INDICES);
let lsb = simd_swizzle!(hex_digits, LSB_INDICES);
//println!("msb: {msb:04x?}");
//println!("lsb: {lsb:04x?}");
let msb = msb.cast::<u16>();
let lsb = lsb.cast::<u16>();
msb << Simd::splat(4) | lsb | ((lsb & Simd::splat(0xf0)) << Simd::splat(8))
pub struct HexByteDecoderA;
impl const HexByteDecoder for HexByteDecoderA {
#[inline(always)]
fn decode_unpacked(hi: u8, lo: u8) -> u16 {
let hi = hex_digit(hi) as u16;
let lo = hex_digit(lo) as u16;
// might these these masks allow the ORs the be pipelined more efficiently?
(hi << 4) | (lo & 0xf) | ((lo & 0xf0) << 8)
}
#[inline(always)]
fn decode_packed([hi, lo]: &[u8; 2]) -> u16 {
let hi = hex_digit(*hi) as u16;
let lo = hex_digit(*lo) as u16;
(hi << 4) | (lo & 0xf) | ((lo & 0xf0) << 8)
}
}
impl HexByteSimdDecoder for HexByteDecoderA {
#[inline(always)]
fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>> {
let hex_digits = hex_digit_simd::<DIGIT_BATCH_SIZE>(Simd::from_array(hi_los));
if ((hex_digits & simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT))
.simd_ne(simd::splat_0::<DIGIT_BATCH_SIZE>()))
.any()
{
return None;
}
let msb = simd_swizzle!(hex_digits, MSB_INDICES);
let lsb = simd_swizzle!(hex_digits, LSB_INDICES);
/*let msb = msb.cast::<u16>();
let lsb = lsb.cast::<u16>();
let buf = msb << simd::splat_n::<WIDE_BATCH_SIZE>(4) | lsb | ((lsb & simd::splat_n::<WIDE_BATCH_SIZE>(0xf0)) << simd::splat_n::<WIDE_BATCH_SIZE>(8));
if buf.simd_gt(simd::splat_n::<WIDE_BATCH_SIZE>(u8::MAX as u16)).any() {
return None;
}
Some(buf.cast::<u8>())*/
Some((msb << simd::splat_n::<WIDE_BATCH_SIZE>(4)) | lsb)
}
}
pub struct HexByteDecoderB;
impl const HexByteDecoder for HexByteDecoderB {
util::defer_impl! {
=> HexByteDecoderA;
//fn decode_unpacked(hi: u8, lo: u8) -> u16;
//fn decode_packed(hi_lo: &[u8; 2]) -> u16;
}
#[inline(always)]
fn decode_unpacked(hi: u8, lo: u8) -> u16 {
let lo = hex_digit(lo) as u16;
let hi = hex_digit(hi) as u16;
// kind of bizarre: changing the order of these decreases perf by 6-12%
(hi << 4) | lo | ((lo & INVALID_BIT as u16) << 8)
}
#[inline(always)]
fn decode_packed([hi, lo]: &[u8; 2]) -> u16 {
let lo = hex_digit(*lo) as u16;
let hi = hex_digit(*hi) as u16;
(hi << 4) | lo | ((lo & INVALID_BIT as u16) << 8)
}
}
impl HexByteSimdDecoder for HexByteDecoderB {
util::defer_impl! {
=> HexByteDecoderA;
//fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>>;
}
#[inline(always)]
fn decode_simd(mut hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>> {
for b in hi_los.iter_mut() {
*b = hex_digit(*b);
}
let hex_digits = Simd::from_array(hi_los);
if (hex_digits & simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT))
.simd_ne(simd::splat_0::<DIGIT_BATCH_SIZE>())
.any()
{
//if hex_digits.simd_eq(simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT)).any() {
return None;
}
let msb = simd_swizzle!(hex_digits, MSB_INDICES);
let lsb = simd_swizzle!(hex_digits, LSB_INDICES);
let mut v = Simd::from_array([0u8; WIDE_BATCH_SIZE]);
for (i, v) in v.as_mut_array().iter_mut().enumerate() {
let hi = unsafe { *msb.as_array().get_unchecked(i) };
let lo = unsafe { *lsb.as_array().get_unchecked(i) };
*v = (hi << 4) | lo;
}
Some(v)
//msb << simd::splat_n::<WIDE_BATCH_SIZE>(4) | lsb | ((lsb & simd::splat_n::<WIDE_BATCH_SIZE>(0xf0)) << simd::splat_n::<WIDE_BATCH_SIZE>(8))
}
}
pub type HBD = HexByteDecoderB;
pub mod conv {
use std::simd::{Simd, LaneCount, SupportedLaneCount};
use std::simd::{LaneCount, Simd, SupportedLaneCount};
/*trait Size {
const N: usize;
@ -241,19 +348,29 @@ pub mod conv {
size_of_impl!(u64 = SizeU64);*/
#[allow(non_camel_case_types, non_snake_case)]
union u8_u16<const N_u16: usize> where [u8; N_u16 * 2]: {
union u8_u16<const N_u16: usize>
where
[u8; N_u16 * 2]:,
{
u8: [u8; N_u16 * 2],
u16: [u16; N_u16],
}
#[allow(non_camel_case_types, non_snake_case)]
union u8x2_u8<const N_u8x2: usize> where [u8; N_u8x2 * 2]: {
union u8x2_u8<const N_u8x2: usize>
where
[u8; N_u8x2 * 2]:,
{
u8x2: [[u8; 2]; N_u8x2],
u8: [u8; N_u8x2 * 2],
}
#[allow(non_camel_case_types, non_snake_case)]
union SimdU8_SimdU16<const N_U16: usize> where LaneCount<{ N_U16 * 2 }>: SupportedLaneCount, LaneCount<N_U16>: SupportedLaneCount {
union SimdU8_SimdU16<const N_U16: usize>
where
LaneCount<{ N_U16 * 2 }>: SupportedLaneCount,
LaneCount<N_U16>: SupportedLaneCount,
{
SimdU8: Simd<u8, { N_U16 * 2 }>,
SimdU16: Simd<u16, N_U16>,
}
@ -269,10 +386,12 @@ pub mod conv {
}
#[inline(always)]
pub const fn simdu8_to_simdu16<const N_OUT: usize>(a: Simd<u8, { N_OUT * 2 }>) -> Simd<u16, N_OUT>
pub const fn simdu8_to_simdu16<const N_OUT: usize>(
a: Simd<u8, { N_OUT * 2 }>,
) -> Simd<u16, N_OUT>
where
LaneCount<{ N_OUT * 2 }>: SupportedLaneCount,
LaneCount<N_OUT>: SupportedLaneCount
LaneCount<N_OUT>: SupportedLaneCount,
{
unsafe { SimdU8_SimdU16 { SimdU8: a }.SimdU16 }
}
@ -296,16 +415,15 @@ const fn align_up_to<const N: usize>(n: usize) -> usize {
return (n + (N - 1)) >> shift << shift;
}
macro_rules! decode_hex_bytes_non_vectored {
($i:ident, $ascii:ident, $bytes:ident, $o:expr) => {{
while $i < $ascii.len() {
match unsafe { hex_byte(*$ascii.get_unchecked($i), *$ascii.get_unchecked($i + 1)) } {
Some(b) => unsafe { *$bytes.get_unchecked_mut($o + ($i >> 1)) = MaybeUninit::new(b) },
None => {
println!("bad hex byte at {} ({}{})", $i, $ascii[$i] as char, $ascii[$i + 1] as char);
//println!("bad hex byte at {} ({}{})", $i, $ascii[$i] as char, $ascii[$i + 1] as char);
return false
},
}
}
$i += 2;
}
@ -314,29 +432,34 @@ macro_rules! decode_hex_bytes_non_vectored {
#[inline(always)]
fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bool {
debug_assert_eq!(ascii.len() >> 1 << 1, ascii.len(), "len of ascii is not a multiple of 2");
debug_assert_eq!(ascii.len() >> 1, bytes.len(), "len of ascii is not twice that of bytes");
const VECTORED: bool = true;
if VECTORED {
union Aligned {
bytes: [u8; DIGIT_BATCH_SIZE],
simd: Simd<u8, DIGIT_BATCH_SIZE>,
}
debug_assert_eq!(
ascii.len() >> 1 << 1,
ascii.len(),
"len of ascii is not a multiple of 2"
);
debug_assert_eq!(
ascii.len() >> 1,
bytes.len(),
"len of ascii is not twice that of bytes"
);
const VECTORED_A: bool = false;
const VECTORED_B: bool = false;
const VECTORED_C: bool = false;
if VECTORED_A {
let mut i = 0;
while i < align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) {
let slice = unsafe { &*(ascii.as_ptr().add(i) as *const [u8; DIGIT_BATCH_SIZE]) };
let aligned = Aligned { bytes: *slice };
let buf = hex_byte_simd(unsafe { aligned.simd });
if buf > Simd::splat(u8::MAX as u16) {
println!("bad hex byte at {}..{}", i, i + DIGIT_BATCH_SIZE);
return false;
}
let buf = buf.cast::<u8>();
let buf = HBD::decode_simd(unsafe {
*(ascii.as_ptr().add(i) as *const [u8; DIGIT_BATCH_SIZE])
});
let buf = match buf {
Some(buf) => buf,
None => return false,
};
let mut j = 0;
while j < DIGIT_BATCH_SIZE {
unsafe {
*bytes.get_unchecked_mut((i >> 1) + j) = MaybeUninit::new(*buf.as_array().get_unchecked(j))
*bytes.get_unchecked_mut((i >> 1) + j) =
MaybeUninit::new(*buf.as_array().get_unchecked(j))
};
j += 1;
}
@ -344,28 +467,30 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
}
decode_hex_bytes_non_vectored!(i, ascii, bytes, 0);
} else if false {
let (ascii_pre, ascii_simd, ascii_post) = unsafe { ascii.align_to::<Simd<u8, DIGIT_BATCH_SIZE>>() };
} else if VECTORED_B {
let (ascii_pre, ascii_simd, ascii_post) =
unsafe { ascii.align_to::<Simd<u8, DIGIT_BATCH_SIZE>>() };
assert_eq!(ascii_pre.len() % 2, 0);
assert_eq!(ascii_post.len() % 2, 0);
debug_assert_eq!(ascii_pre.len() % 2, 0);
debug_assert_eq!(ascii_post.len() % 2, 0);
let mut i = 0;
decode_hex_bytes_non_vectored!(i, ascii_pre, bytes, 0);
let mut i = 0;
while i < ascii_simd.len() {
let buf = hex_byte_simd(unsafe { *ascii_simd.get_unchecked(i) });
if buf > Simd::splat(u8::MAX as u16) {
println!("bad hex byte at {}..{}", i, i + DIGIT_BATCH_SIZE);
return false;
}
let buf = buf.cast::<u8>();
// this to_array and any subsequent from_array should be eliminated anyway
let buf = HBD::decode_simd(unsafe { ascii_simd.get_unchecked(i) }.to_array());
let buf = match buf {
Some(buf) => buf,
None => return false,
};
let mut j = 0;
let k = ascii_pre.len() + i * DIGIT_BATCH_SIZE;
while j < DIGIT_BATCH_SIZE {
unsafe {
*bytes.get_unchecked_mut(k + j) = MaybeUninit::new(*buf.as_array().get_unchecked(j))
*bytes.get_unchecked_mut(k + j) =
MaybeUninit::new(*buf.as_array().get_unchecked(j))
};
j += 1;
}
@ -375,6 +500,28 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
let mut i = 0;
let k = ascii.len() - ascii_post.len();
decode_hex_bytes_non_vectored!(i, ascii_post, bytes, k);
} else if VECTORED_C {
let mut i = 0;
while i < align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) {
let buf = HBD::decode_simd(unsafe {
*(ascii.as_ptr().add(i) as *const [u8; DIGIT_BATCH_SIZE])
});
let buf = match buf {
Some(buf) => buf,
None => return false,
};
let mut j = 0;
while j < DIGIT_BATCH_SIZE {
unsafe {
*bytes.get_unchecked_mut((i >> 1) + j) =
MaybeUninit::new(*buf.as_array().get_unchecked(j))
};
j += 1;
}
i += DIGIT_BATCH_SIZE;
}
decode_hex_bytes_non_vectored!(i, ascii, bytes, 0);
} else {
let mut i = 0;
decode_hex_bytes_non_vectored!(i, ascii, bytes, 0);
@ -391,14 +538,17 @@ pub const fn hex_bytes_sized_const<const N: usize>(ascii: &[u8; N * 2]) -> Optio
} else {
let mut bytes = MaybeUninit::<u8>::uninit_array::<N>();
let mut i = 0;
while i < bytes.len() {
match hex_byte(unsafe { *ascii.get_unchecked(i << 1) }, unsafe {
*ascii.get_unchecked((i << 1) + 1)
while i < N * 2 {
if i >> 1 >= bytes.len() {
unsafe { std::hint::unreachable_unchecked() };
}
match hex_byte(unsafe { *ascii.get_unchecked(i) }, unsafe {
*ascii.get_unchecked(i + 1)
}) {
Some(b) => bytes[i] = MaybeUninit::new(b),
Some(b) => bytes[i >> 1] = MaybeUninit::new(b),
None => return None,
}
i += 1;
i += 2;
}
Some(unsafe { MaybeUninit::array_assume_init(bytes) })
}
@ -418,12 +568,12 @@ pub fn hex_bytes_sized<const N: usize>(ascii: &[u8; N * 2]) -> Option<[u8; N]> {
}
}
#[inline]
pub fn hex_bytes_sized_heap<const N: usize>(ascii: &[u8; N * 2]) -> Option<Box<[u8; N]>> {
if N == 0 {
Some(Box::new([0u8; N]))
} else {
let mut bytes: Box<[MaybeUninit<u8>; N]> =
unsafe { Box::<[MaybeUninit<u8>; N]>::new_uninit().assume_init() };
let mut bytes = unsafe { Box::<[MaybeUninit<u8>; N]>::new_uninit().assume_init() };
if decode_hex_bytes_unchecked(ascii, bytes.as_mut()) {
Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8; N]) })
} else {
@ -432,19 +582,21 @@ pub fn hex_bytes_sized_heap<const N: usize>(ascii: &[u8; N * 2]) -> Option<Box<[
}
}
#[inline]
pub fn hex_bytes_dyn_unsafe(ascii: &[u8]) -> Option<Box<[u8]>> {
let len = ascii.len() >> 1;
if len << 1 != ascii.len() {
return None;
}
let mut bytes = Box::<[u8]>::new_uninit_slice(len);
let mut bytes = Box::new_uninit_slice(len);
if decode_hex_bytes_unchecked(ascii, bytes.as_mut()) {
Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8]) })
Some(unsafe { Box::<[_]>::assume_init(bytes) })
} else {
None
}
}
#[inline]
pub fn hex_bytes_dyn_unsafe_iter(ascii: &[u8]) -> Option<Box<[u8]>> {
let len = ascii.len() >> 1;
if len << 1 != ascii.len() {
@ -462,9 +614,10 @@ pub fn hex_bytes_dyn_unsafe_iter(ascii: &[u8]) -> Option<Box<[u8]>> {
return None;
}
}
Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8]) })
Some(unsafe { Box::<[_]>::assume_init(bytes) })
}
#[inline]
pub fn hex_bytes_dyn_unsafe_iter_niched(ascii: &[u8]) -> Option<Box<[u8]>> {
let len = ascii.len() >> 1;
if len << 1 != ascii.len() {
@ -473,18 +626,18 @@ pub fn hex_bytes_dyn_unsafe_iter_niched(ascii: &[u8]) -> Option<Box<[u8]>> {
let mut bytes = Box::<[u8]>::new_uninit_slice(len);
for (i, b) in ascii
.array_chunks::<2>()
.map(|[msb, lsb]| hex_byte_niched(*msb, *lsb))
.map(HBD::decode_packed)
.enumerate()
{
if b & 0xff_00 == 0 {
unsafe { *bytes.get_unchecked_mut(i) = MaybeUninit::new(b as u8) };
} else {
if b & WIDE_INVALID_BIT != 0 {
return None;
}
unsafe { *bytes.get_unchecked_mut(i) = MaybeUninit::new(b as u8) };
}
Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8]) })
Some(unsafe { Box::<[_]>::assume_init(bytes) })
}
#[inline]
pub fn hex_bytes_dyn(ascii: &[u8]) -> Option<Box<[u8]>> {
let iter = ascii.array_chunks::<2>();
if iter.remainder().len() != 0 {
@ -495,13 +648,20 @@ pub fn hex_bytes_dyn(ascii: &[u8]) -> Option<Box<[u8]>> {
.map(|v| v.into_boxed_slice())
}
#[inline]
pub fn hex_bytes_dyn_niched(ascii: &[u8]) -> Option<Box<[u8]>> {
let iter = ascii.array_chunks::<2>();
if iter.remainder().len() != 0 {
return None;
}
iter.map(|[msb, lsb]| hex_byte_niched(*msb, *lsb))
.map(|b| if b & 0xff_00 == 0 { Some(b as u8) } else { None })
iter.map(HBD::decode_packed)
.map(|b| {
if b & WIDE_INVALID_BIT != 0 {
None
} else {
Some(b as u8)
}
})
.collect::<Option<Vec<u8>>>()
.map(|v| v.into_boxed_slice())
}
@ -522,8 +682,21 @@ mod test {
}
const SAMPLES: &[Sample] = &[
Sample { bytes: BYTES, hex_bytes: HEX_BYTES },
Sample { bytes: LONG_BYTES, hex_bytes: LONG_HEX_BYTES },
Sample {
bytes: BYTES,
hex_bytes: HEX_BYTES,
},
Sample {
bytes: LONG_BYTES,
hex_bytes: LONG_HEX_BYTES,
},
];
const INVALID_SAMPLES: &[&[u8]] = &[
b"446F6C6F72756D2064697374696E6374696F20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722G",
b"446F6C6F72756D2064697374696E6374696F20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E7365637465747572GE",
b"446F6C6F72756D2064697374696E6374696G20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E",
b"446F6C6F72756D2064697374696E637469GF20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E",
];
#[test]
@ -547,8 +720,10 @@ mod test {
#[test]
fn test_hex_digit_simd() {
const HEX_DIGITS: &[char; DIGIT_BATCH_SIZE] = &[
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0',
'1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E',
'F',
// '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
];
let mut set8 = [0u8; DIGIT_BATCH_SIZE];
@ -559,7 +734,10 @@ mod test {
for i in 0..(DIGIT_BATCH_SIZE) {
sete[i] = (i as u8) % 16;
}
assert_eq!(hex_digits(Simd::from_array(set8)), Simd::from_array(sete));
assert_eq!(
hex_digit_simd::<DIGIT_BATCH_SIZE>(Simd::from_array(set8)),
Simd::from_array(sete)
);
}
#[test]
@ -576,7 +754,12 @@ mod test {
for (hb, b) in HEX_BYTES_VALID {
assert_eq!(hex_byte(hb[0], hb[1]), Some(*b));
assert_eq!(hex_byte_niched(hb[0], hb[1]), *b as u16);
assert_eq!(HexByteDecoderA::decode_unpacked(hb[0], hb[1]), *b as u16);
assert_eq!(HexByteDecoderB::decode_unpacked(hb[0], hb[1]), *b as u16);
assert_eq!(HexByteDecoderA::decode_packed(hb), *b as u16);
assert_eq!(HexByteDecoderB::decode_packed(hb), *b as u16);
}
const HEX_BYTES_INVALID: &[[u8; 2]] = &[
@ -591,39 +774,47 @@ mod test {
for hb in HEX_BYTES_INVALID {
assert_eq!(hex_byte(hb[0], hb[1]), None);
assert!(hex_byte_niched(hb[0], hb[1]) & 0xff_00 != 0);
assert_ne!(
HexByteDecoderA::decode_unpacked(hb[0], hb[1]) & WIDE_INVALID_BIT,
0
);
assert_ne!(
HexByteDecoderB::decode_unpacked(hb[0], hb[1]) & WIDE_INVALID_BIT,
0
);
assert_ne!(HexByteDecoderA::decode_packed(hb) & WIDE_INVALID_BIT, 0);
assert_ne!(HexByteDecoderB::decode_packed(hb) & WIDE_INVALID_BIT, 0);
}
}
#[test]
fn test_hex_byte_simd() {
const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", *b"ff", *b"00", *b"11",
*b"ef", *b"fe", *b"0f", *b"f0",
*b"34",
// *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
// *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
];
const BYTES_VALID: [u8; WIDE_BATCH_SIZE] = [
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f,
0xf0,
0x34,
// 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
// 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
];
let hex_bytes = Simd::from_array(conv::u8x2_to_u8(HEX_BYTES_VALID));
let bytes = BYTES_VALID.map(|b| b as u16);
let bytes = Simd::from_array(bytes);
let hex_bytes = conv::u8x2_to_u8(HEX_BYTES_VALID);
let bytes = Simd::from_array(BYTES_VALID);
println!("hex_bytes: {HEX_BYTES_VALID:02x?}");
println!("hex_bytes: {hex_bytes:02x?}");
println!("bytes: {BYTES_VALID:02x?}");
println!("bytes: {bytes:04x?}");
assert_eq!(hex_byte_simd(hex_bytes), bytes);
assert_eq!(HexByteDecoderA::decode_simd(hex_bytes), Some(bytes));
assert_eq!(HexByteDecoderB::decode_simd(hex_bytes), Some(bytes));
/*const HEX_BYTES_INVALID: &[[u8; 2]] = &[
['f' as u8, 'g' as u8],
@ -637,14 +828,27 @@ mod test {
for hb in HEX_BYTES_INVALID {
assert_eq!(hex_byte(hb[0], hb[1]), None);
assert!(hex_byte_niched(hb[0], hb[1]) & 0xff_00 != 0);
assert!(hex_byte_niched(hb[0], hb[1]) & WIDE_INVALID_BIT != 0);
}*/
}
fn test_f(f: fn(&[u8]) -> Option<Box<[u8]>>) {
for (i, Sample { bytes, hex_bytes }) in SAMPLES.into_iter().enumerate() {
let result = f(hex_bytes);
assert_eq!(Some(*bytes), result.as_ref().map(Box::as_ref), "Sample {i} did not decode correctly");
assert_eq!(
Some(*bytes),
result.as_ref().map(Box::as_ref),
"Sample {i} did not decode correctly"
);
}
for (i, hex_bytes) in INVALID_SAMPLES.into_iter().enumerate() {
let result = f(hex_bytes);
assert_eq!(
None,
result.as_ref().map(Box::as_ref),
"Sample {i} did not decode correctly"
);
}
}

89
src/simd.rs Normal file
View File

@ -0,0 +1,89 @@
use std::simd::{LaneCount, Simd, SupportedLaneCount};
use crate::util::cast;
pub trait SimdSplatZero<const LANES: usize> {
fn splat_zero() -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount;
}
pub trait SimdSplatN<const LANES: usize> {
fn splat_n(n: u8) -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount;
}
pub struct SimdOps;
const W_128: usize = 128 / 8;
const W_256: usize = 256 / 8;
const W_512: usize = 512 / 8;
macro_rules! specialized {
($LANES:ident, $trait:ident {
$(
fn $name:ident($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
$width:pat_param $( if $cfg:meta )? => $impl:expr
),+
$(,)?
}
)*
}) => {
impl<const $LANES: usize> $trait<$LANES> for SimdOps {
$(
#[inline(always)]
fn $name($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
// abusing const generics to specialize without the unsoundness of real specialization!
match $LANES {
$(
$( #[cfg( $cfg )] )?
$width => $impl
),+
}
}
)*
}
};
}
specialized!(LANES, SimdSplatN {
fn splat_n(n: u8) -> Simd<u8, LANES> where [LaneCount<LANES>: SupportedLaneCount] {
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_set1_epi8(n as i8)) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi8(n as i8)) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi8(n as i8)) },
_ => Simd::splat(n),
}
});
specialized!(LANES, SimdSplatZero {
fn splat_zero() -> Simd<u8, LANES> where [LaneCount<LANES>: SupportedLaneCount] {
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
_ => <Self as SimdSplatN<LANES>>::splat_n(0),
}
});
#[inline(always)]
pub fn splat_0<const LANES: usize>() -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount,
{
<SimdOps as SimdSplatZero<LANES>>::splat_zero()
}
#[inline(always)]
pub fn splat_n<const LANES: usize>(n: u8) -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount,
{
<SimdOps as SimdSplatN<LANES>>::splat_n(n)
}

55
src/util.rs Normal file
View File

@ -0,0 +1,55 @@
#[doc(hidden)]
#[macro_export]
macro_rules! __defer_impl {
(
=> $impl:ident;
$( fn $name:ident($( $pname:ident: $pty:ty ),*) -> $rty:ty; )*
) => {
$(
#[inline(always)]
fn $name($( $pname: $pty ),*) -> $rty {
<$impl>::$name($( $pname ),*)
}
)*
};
}
pub use __defer_impl as defer_impl;
#[inline(always)]
#[cold]
pub fn cold() {}
#[inline]
pub fn likely(b: bool) -> bool {
if !b {
cold()
}
b
}
#[inline]
pub fn unlikely(b: bool) -> bool {
if b {
cold()
}
b
}
/// Like transmute, but implemented via a union so that we can use it in situations where
/// transmute's "safety" restrictions are too strict and uninformed (i.e. we can prove it is safe).
#[inline(always)]
pub unsafe fn cast<A, B>(a: A) -> B {
union Cast<A, B> {
a: std::mem::ManuallyDrop<A>,
b: std::mem::ManuallyDrop<B>,
}
std::mem::ManuallyDrop::into_inner(
Cast {
a: std::mem::ManuallyDrop::new(a),
}
.b,
)
}