Optimized vectorized validation

- Validation overhead is reduced to at most 5%.
This commit is contained in:
Michael Pfaff 2022-10-31 23:57:40 -04:00
parent 8210b6ac40
commit ee3b6d84e4
Signed by: michael
GPG Key ID: CF402C4A012AA9D4
2 changed files with 40 additions and 6 deletions

View File

@ -30,6 +30,7 @@ pub const WIDE_BATCH_SIZE: usize = SIMD_WIDTH / 16;
/// Number of `u8` lanes in one batch of hex digits; twice `WIDE_BATCH_SIZE`.
/// The vectorized decode loop consumes the input in multiples of this size.
pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2;
/// Debug switch: when `true`, the vectorized decoder prints intermediate vectors
/// (such as `hex_digits`) with `println!`.
const TRACE_SIMD: bool = false;
/// When `true`, the vectorized decoder ORs every batch of decoded `hex_digits` into the
/// `bad` accumulator so invalid input can be detected with a single test after the loop.
const VALIDATE: bool = true;
#[inline]
const fn alternating_mask<const N: usize>(first_bias: bool) -> [bool; N] {
@ -497,8 +498,8 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64 as arch;
let mut bad: arch::__m256i = simd::splat_0().into();
let mut i = 0;
let mut bad = Mask::splat(false);
while i < util::align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) {
const GATHER_BATCH_SIZE: usize = DIGIT_BATCH_SIZE / 4;
let hex_digits = unsafe {
@ -590,9 +591,12 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
// merge the xmm0 and xmm1 (ymm1) registers into ymm0
simd::merge_m128_m256(ab, cd)
};
if VALIDATE {
unsafe {
std::arch::asm!("vpor {bad}, {digits}, {bad}", bad = inout(ymm_reg) bad, digits = in(ymm_reg) hex_digits);
}
}
let hex_digits: Simd<u8, DIGIT_BATCH_SIZE> = hex_digits.into();
bad |= (hex_digits & simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT))
.simd_ne(simd::splat_0::<DIGIT_BATCH_SIZE>());
if TRACE_SIMD {
println!("hex_digits: {hex_digits:x?}");
}
@ -658,7 +662,8 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
}
decode_hex_bytes_non_vectored!(i, ascii, bytes);
!bad.any()
use simd::SimdTestAnd;
!bad.test_and_non_zero(simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT).into())
} else {
let mut i = 0;
decode_hex_bytes_non_vectored!(i, ascii, bytes);
@ -942,8 +947,8 @@ mod test {
for (i, Sample { bytes, hex_bytes }) in SAMPLES.into_iter().enumerate() {
let result = $f(hex_bytes.as_bytes());
assert_eq!(
Some(bytes.as_bytes()),
result.as_ref().map($trans),
Some(bytes.as_bytes()),
"Sample {i} ({hex_bytes:?} => {bytes:?}) did not decode correctly (expected Some)"
);
}
@ -951,8 +956,8 @@ mod test {
for (i, hex_bytes) in INVALID_SAMPLES.into_iter().enumerate() {
let result = $f(hex_bytes.as_bytes());
assert_eq!(
None,
result.as_ref().map($trans),
None,
"Sample {i} ({hex_bytes:?}) did not decode correctly (expected None)"
);
}

View File

@ -201,3 +201,32 @@ pub fn extract_lo_bytes(v: arch::__m256i) -> arch::__m128i {
/// Extracts bytes from the 256-bit vector `v` into a 128-bit vector via `extract_lohi_bytes!`.
///
/// The shuffle table passed to the macro selects source byte indices 1, 3, 5, ..., 15
/// (the odd-indexed bytes) followed by eight zero indices, together with the instruction
/// names `vpshufb` and `vpunpcklqdq`.
///
/// NOTE(review): presumably this yields the high byte of each 16-bit lane of `v`, packed
/// into the result — confirm against the definition of `extract_lohi_bytes!`.
pub fn extract_hi_bytes(v: arch::__m256i) -> arch::__m128i {
extract_lohi_bytes!(([0x1u8, 0x3, 0x5, 0x7, 0x9, 0xb, 0xd, 0xf, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0], vpshufb, vpunpcklqdq), v)
}
/// Bit-test of a SIMD vector against a mask of the same type.
///
/// Both operands are taken by value and have the same type (`Self`).
pub trait SimdTestAnd {
/// Returns true if the result of the bitwise AND of `self` and `mask` is not all zero,
/// i.e. if at least one bit is set in both `self` and `mask`.
fn test_and_non_zero(self, mask: Self) -> bool;
}
#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m128i {
    /// Returns `true` if `self & mask` has at least one bit set.
    ///
    /// Uses the VEX-encoded `vptest`, which sets ZF exactly when the bitwise AND of its two
    /// operands is all zero; `setnz` then materializes `!ZF` as 0 or 1 without branching.
    #[inline(always)]
    fn test_and_non_zero(self, mask: Self) -> bool {
        let out: u8;
        // SAFETY: this impl only exists when the `avx` target feature is statically enabled,
        // so `vptest` on xmm registers is available. The asm reads only its two register
        // inputs, writes only `out` and the flags, and touches neither memory nor the stack.
        unsafe {
            std::arch::asm!(
                "vptest {a}, {b}",
                "setnz {out}",
                a = in(xmm_reg) self,
                b = in(xmm_reg) mask,
                out = out(reg_byte) out,
                options(pure, nomem, nostack),
            );
        }
        // `setnz` writes exactly 0 or 1; compare instead of transmuting into `bool`.
        out != 0
    }
}
#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m256i {
    /// Returns `true` if `self & mask` has at least one bit set.
    ///
    /// Uses `vptest`, which sets ZF exactly when the bitwise AND of its two operands is all
    /// zero; `setnz` then materializes `!ZF` as 0 or 1 without branching.
    #[inline(always)]
    fn test_and_non_zero(self, mask: Self) -> bool {
        let out: u8;
        // SAFETY: this impl only exists when the `avx` target feature is statically enabled,
        // so `vptest` on ymm registers is available. The asm reads only its two register
        // inputs, writes only `out` and the flags, and touches neither memory nor the stack.
        unsafe {
            std::arch::asm!(
                "vptest {a}, {b}",
                "setnz {out}",
                a = in(ymm_reg) self,
                b = in(ymm_reg) mask,
                out = out(reg_byte) out,
                options(pure, nomem, nostack),
            );
        }
        // `setnz` writes exactly 0 or 1; compare instead of transmuting into `bool`.
        out != 0
    }
}