Optimized vectorized validation
- Validation overhead is cut down to at most 5%
This commit is contained in:
parent
8210b6ac40
commit
ee3b6d84e4
17
src/lib.rs
17
src/lib.rs
|
@ -30,6 +30,7 @@ pub const WIDE_BATCH_SIZE: usize = SIMD_WIDTH / 16;
|
|||
pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2;
|
||||
|
||||
/// Debug switch. In the code visible here, the vectorized decode loop prints each batch of
/// decoded digits with `println!` when this is `true`.
const TRACE_SIMD: bool = false;
|
||||
/// Validation switch. In the code visible here, the vectorized decode loop ORs every batch of
/// decoded digits into the `bad` accumulator when this is `true`; `bad` is later tested
/// against `INVALID_BIT` to produce the function's boolean result.
const VALIDATE: bool = true;
|
||||
|
||||
#[inline]
|
||||
const fn alternating_mask<const N: usize>(first_bias: bool) -> [bool; N] {
|
||||
|
@ -497,8 +498,8 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
|
|||
#[cfg(target_arch = "x86_64")]
|
||||
use std::arch::x86_64 as arch;
|
||||
|
||||
let mut bad: arch::__m256i = simd::splat_0().into();
|
||||
let mut i = 0;
|
||||
let mut bad = Mask::splat(false);
|
||||
while i < util::align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) {
|
||||
const GATHER_BATCH_SIZE: usize = DIGIT_BATCH_SIZE / 4;
|
||||
let hex_digits = unsafe {
|
||||
|
@ -590,9 +591,12 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
|
|||
// merge the xmm0 and xmm1 (ymm1) registers into ymm0
|
||||
simd::merge_m128_m256(ab, cd)
|
||||
};
|
||||
if VALIDATE {
|
||||
unsafe {
|
||||
std::arch::asm!("vpor {bad}, {digits}, {bad}", bad = inout(ymm_reg) bad, digits = in(ymm_reg) hex_digits);
|
||||
}
|
||||
}
|
||||
let hex_digits: Simd<u8, DIGIT_BATCH_SIZE> = hex_digits.into();
|
||||
bad |= (hex_digits & simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT))
|
||||
.simd_ne(simd::splat_0::<DIGIT_BATCH_SIZE>());
|
||||
if TRACE_SIMD {
|
||||
println!("hex_digits: {hex_digits:x?}");
|
||||
}
|
||||
|
@ -658,7 +662,8 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
|
|||
}
|
||||
|
||||
decode_hex_bytes_non_vectored!(i, ascii, bytes);
|
||||
!bad.any()
|
||||
use simd::SimdTestAnd;
|
||||
!bad.test_and_non_zero(simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT).into())
|
||||
} else {
|
||||
let mut i = 0;
|
||||
decode_hex_bytes_non_vectored!(i, ascii, bytes);
|
||||
|
@ -942,8 +947,8 @@ mod test {
|
|||
for (i, Sample { bytes, hex_bytes }) in SAMPLES.into_iter().enumerate() {
|
||||
let result = $f(hex_bytes.as_bytes());
|
||||
assert_eq!(
|
||||
Some(bytes.as_bytes()),
|
||||
result.as_ref().map($trans),
|
||||
Some(bytes.as_bytes()),
|
||||
"Sample {i} ({hex_bytes:?} => {bytes:?}) did not decode correctly (expected Some)"
|
||||
);
|
||||
}
|
||||
|
@ -951,8 +956,8 @@ mod test {
|
|||
for (i, hex_bytes) in INVALID_SAMPLES.into_iter().enumerate() {
|
||||
let result = $f(hex_bytes.as_bytes());
|
||||
assert_eq!(
|
||||
None,
|
||||
result.as_ref().map($trans),
|
||||
None,
|
||||
"Sample {i} ({hex_bytes:?}) did not decode correctly (expected None)"
|
||||
);
|
||||
}
|
||||
|
|
29
src/simd.rs
29
src/simd.rs
|
@ -201,3 +201,32 @@ pub fn extract_lo_bytes(v: arch::__m256i) -> arch::__m128i {
|
|||
/// Invokes the `extract_lohi_bytes!` macro on `v` with a 16-entry byte-index table that lists
/// the odd positions `0x1, 0x3, …, 0xf` followed by eight zeros, together with the `vpshufb`
/// and `vpunpcklqdq` mnemonics.
///
/// NOTE(review): presumably this gathers the high (odd-indexed) byte of each 16-bit lane of
/// `v` into a 128-bit result — confirm against the macro definition, which is not shown here.
pub fn extract_hi_bytes(v: arch::__m256i) -> arch::__m128i {
|
||||
extract_lohi_bytes!(([0x1u8, 0x3, 0x5, 0x7, 0x9, 0xb, 0xd, 0xf, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0], vpshufb, vpunpcklqdq), v)
|
||||
}
|
||||
|
||||
/// A "test under mask" operation: checks whether a value shares any set bit with a mask of
/// the same type.
///
/// The implementations visible in this file cover `arch::__m128i` and `arch::__m256i` and are
/// only compiled when the `avx` target feature is enabled.
pub trait SimdTestAnd {
|
||||
/// Returns true if the result of the bitwise AND of `self` and `mask` is not all zero.
///
/// Both operands are consumed by value; `mask` has the same type as `self`.
|
||||
fn test_and_non_zero(self, mask: Self) -> bool;
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m128i {
    /// Returns `true` when `self & mask` has at least one bit set.
    ///
    /// Issues a single `vptest`, which sets ZF exactly when the bitwise AND of its two
    /// operands is all zero, and then materializes the inverted flag with `setnz`. This
    /// yields the same 0/1 result as branching on ZF, without any jumps inside the asm block.
    #[inline(always)]
    fn test_and_non_zero(self, mask: Self) -> bool {
        let any_set: u8;
        // SAFETY: this impl is only compiled when the `avx` target feature is enabled, so the
        // VEX-encoded `vptest` may be emitted. The asm reads two vector registers, writes one
        // byte register plus the flags, and touches neither memory nor the stack, which is
        // what `pure, nomem, nostack` promise. Flags are clobbered, so `preserves_flags` is
        // deliberately not specified.
        unsafe {
            std::arch::asm!(
                "vptest {a}, {b}",
                "setnz {out}",
                a = in(xmm_reg) self,
                b = in(xmm_reg) mask,
                out = out(reg_byte) any_set,
                options(pure, nomem, nostack),
            );
        }
        // `setnz` only ever writes 0 or 1; a comparison avoids a `u8 -> bool` transmute.
        any_set != 0
    }
}
|
||||
|
||||
#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m256i {
    /// Returns `true` when `self & mask` has at least one bit set.
    ///
    /// Issues a single 256-bit `vptest`, which sets ZF exactly when the bitwise AND of its
    /// two operands is all zero, and then materializes the inverted flag with `setnz`. This
    /// yields the same 0/1 result as branching on ZF, without any jumps inside the asm block.
    #[inline(always)]
    fn test_and_non_zero(self, mask: Self) -> bool {
        let any_set: u8;
        // SAFETY: this impl is only compiled when the `avx` target feature is enabled, so the
        // VEX-encoded `vptest` on `ymm` registers may be emitted. The asm reads two vector
        // registers, writes one byte register plus the flags, and touches neither memory nor
        // the stack, which is what `pure, nomem, nostack` promise. Flags are clobbered, so
        // `preserves_flags` is deliberately not specified.
        unsafe {
            std::arch::asm!(
                "vptest {a}, {b}",
                "setnz {out}",
                a = in(ymm_reg) self,
                b = in(ymm_reg) mask,
                out = out(reg_byte) any_set,
                options(pure, nomem, nostack),
            );
        }
        // `setnz` only ever writes 0 or 1; a comparison avoids a `u8 -> bool` transmute.
        any_set != 0
    }
}
|
||||
|
|
Loading…
Reference in New Issue