fast-hex/src/dec.rs

617 lines
20 KiB
Rust

//! SIMD-accelerated, validating hex decoding.
use core::mem::MaybeUninit;
use core::simd::*;
#[cfg(feature = "alloc")]
use alloc::{boxed::Box, vec::Vec};
use crate::prelude::*;
/// Compile-time switch for the vectored decoder: when `false`, the digits produced by the SIMD
/// loop are not accumulated for validation (the scalar tail still validates its own bytes).
const VALIDATE: bool = true;
/// Marker returned by [`hex_digit`] for any byte that is not an ASCII hex digit.
pub const INVALID_BIT: u8 = 0x80;
/// Invalid-input mask for the `u16` results of [`HexByteDecoder`]: an invalid high digit lands
/// at `INVALID_BIT << 4` and an invalid low digit at `INVALID_BIT << 8`.
pub const WIDE_INVALID_BIT: u16 = ((INVALID_BIT as u16) << 8) | ((INVALID_BIT as u16) << 4);
/// Wrapper that raises the alignment of its contents to 128 bytes.
///
/// Used for the digit lookup tables in this module so that each table starts on a 128-byte
/// boundary (presumably to keep them from sharing cache lines with unrelated data —
/// NOTE(review): confirm intent).
#[repr(align(128))]
struct CachePadded<T>(T);
const ASCII_DIGITS: CachePadded<[u8; 256]> = {
CachePadded(array_op!(
gen[256] | i | {
const DIGIT_MIN: u8 = '0' as u8;
const DIGIT_MAX: u8 = '9' as u8;
const LOWER_MIN: u8 = 'a' as u8;
const LOWER_MAX: u8 = 'f' as u8;
const UPPER_MIN: u8 = 'A' as u8;
const UPPER_MAX: u8 = 'F' as u8;
let i = i as u8;
match i {
DIGIT_MIN..=DIGIT_MAX => i - DIGIT_MIN,
LOWER_MIN..=LOWER_MAX => 10 + i - LOWER_MIN,
UPPER_MIN..=UPPER_MAX => 10 + i - UPPER_MIN,
_ => INVALID_BIT,
}
}
))
};
/// [`ASCII_DIGITS`] with each entry widened to a 32-bit element (presumably element-wise by
/// `util::cast_u8_u32` — NOTE(review): confirm it zero-extends). The widening is needed
/// because the AVX2 `_mm256_i32gather_epi32` lookups in `hex_digits_simd_inline!` fetch 32-bit
/// elements (scale 4).
const __ASCII_DIGITS_SIMD: CachePadded<[u32; 256]> = CachePadded(util::cast_u8_u32(ASCII_DIGITS.0));
/// Base address of [`__ASCII_DIGITS_SIMD`]. Used as the gather base pointer and as the
/// prefetch target in `decode_hex_bytes_unchecked`.
const ASCII_DIGITS_SIMD: *const i32 = __ASCII_DIGITS_SIMD.0.as_ptr() as *const i32;
/// Returns the value (`0..=15`) of the ASCII hex digit `ascii`, accepting upper and lower case,
/// or [`INVALID_BIT`] if `ascii` is not a hex digit. Based on `char.to_digit()` in the stdlib.
#[inline]
pub const fn hex_digit(ascii: u8) -> u8 {
    let table: &[u8; 256] = &ASCII_DIGITS.0;
    table[ascii as usize]
}
/// Lane-wise [`hex_digit`]: maps each ASCII byte to its hex-digit value, or to
/// [`INVALID_BIT`] for non-hex bytes.
#[inline(always)]
pub fn hex_digit_simd<const LANES: usize>(ascii: Simd<u8, LANES>) -> Simd<u8, LANES>
where
    LaneCount<LANES>: SupportedLaneCount,
{
    let indices: Simd<usize, LANES> = ascii.cast();
    // SAFETY: every index is a `u8` widened to `usize`, so it is below 256, the table length;
    // all lanes are enabled and in bounds.
    unsafe { Simd::gather_select_unchecked(&ASCII_DIGITS.0, Mask::splat(true), indices, u8::splat_zero()) }
}
/// Parses an ascii hex byte from its most significant (`msb`) and least significant (`lsb`)
/// digit characters. Returns `None` unless both are ASCII hex digits.
#[inline(always)]
pub const fn hex_byte(msb: u8, lsb: u8) -> Option<u8> {
    let hi = hex_digit(msb);
    let lo = hex_digit(lsb);
    // Masking each digit separately before OR-ing was found faster than
    // `(hi | lo) & INVALID_BIT` (perhaps it pipelines better?).
    let invalid = (hi & INVALID_BIT) | (lo & INVALID_BIT);
    if invalid == 0 {
        Some((hi << 4) | lo)
    } else {
        None
    }
}
/// A decoder for a single hex byte (two ASCII hex digits).
///
/// Results are `u16` so that invalid input can be signalled in-band: any value above
/// [`u8::MAX`] means at least one digit was invalid.
#[const_trait]
pub trait HexByteDecoder {
    /// Decodes the digit characters `hi` (most significant) and `lo`.
    /// Any return value exceeding [`u8::MAX`] indicates invalid input.
    fn decode_unpacked(hi: u8, lo: u8) -> u16;
    /// Same as `decode_unpacked`, with both digit characters in one `[hi, lo]` pair.
    /// Any return value exceeding [`u8::MAX`] indicates invalid input.
    #[inline(always)]
    fn decode_packed(hi_lo: &[u8; 2]) -> u16 {
        let [hi, lo] = *hi_lo;
        Self::decode_unpacked(hi, lo)
    }
}
/// A decoder for a fixed-size batch of hex bytes.
pub trait HexByteSimdDecoder {
    /// Decodes `DIGIT_BATCH_SIZE` ASCII hex digits, taken as consecutive `[msb, lsb]` pairs,
    /// into `WIDE_BATCH_SIZE` bytes.
    ///
    /// Returns `None` if any digit in the batch is invalid. Unlike [`HexByteDecoder`], invalid
    /// input is not signalled in-band: the returned elements are plain `u8` bytes.
    fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>>;
}
/// Table-driven hex byte decoder built on [`hex_digit`].
pub struct HexByteDecoderA;
impl const HexByteDecoder for HexByteDecoderA {
    /// Returns the decoded byte, or a value above `u8::MAX` on invalid input. An invalid `hi`
    /// digit sets bit 11 (`INVALID_BIT << 4`); an invalid `lo` digit sets bits 7 and 15. So
    /// `result & WIDE_INVALID_BIT != 0` exactly when either digit is invalid.
    #[inline(always)]
    fn decode_unpacked(hi: u8, lo: u8) -> u16 {
        let lo_val = hex_digit(lo) as u16;
        let hi_val = hex_digit(hi) as u16;
        // kind of bizarre: changing the order of these decreases perf by 6-12%
        (hi_val << 4) | lo_val | ((lo_val & INVALID_BIT as u16) << 8)
    }
    /// Same as `decode_unpacked`, with both digit characters in one `[hi, lo]` pair.
    #[inline(always)]
    fn decode_packed(hi_lo: &[u8; 2]) -> u16 {
        Self::decode_unpacked(hi_lo[0], hi_lo[1])
    }
}
// Looks up the hex-digit value of `4 * GATHER_BATCH_SIZE` ASCII bytes and returns them packed
// into one `__m256i`, one result byte per input byte (non-hex bytes map to `INVALID_BIT`).
//
// `$ptr` must be a `*const [u8; GATHER_BATCH_SIZE]` with four consecutive readable chunks behind
// it. The expansion dereferences raw pointers and calls AVX2 intrinsics, so it must be expanded
// inside an `unsafe` block.
macro_rules! hex_digits_simd_inline {
    ($ptr:ident) => {{
        if_trace_simd! {
            // Show all four input chunks, not only the first one.
            println!("hi_los: {:x?}", [*$ptr, *$ptr.add(1), *$ptr.add(2), *$ptr.add(3)]);
        }
        // Load the four input chunks.
        unroll!(let [a: (0), b: (1), c: (2), d: (3)] => |i| *$ptr.add(i));
        if_trace_simd! {
            // `CachePadded` is a plain tuple wrapper with no `Index` impl, so index the inner table.
            let f = |x| __ASCII_DIGITS_SIMD.0[x as usize];
            println!(
                "{:x?}, {:x?}, {:x?}, {:x?}",
                a.map(f),
                b.map(f),
                c.map(f),
                d.map(f)
            );
        }
        // Widen each byte to a 32-bit lane: the gather below takes 32-bit indices.
        unroll!(let [a, b, c, d] => |x| Simd::from_array(x));
        unroll!(let [a, b, c, d] => |x| x.cast::<u32>());
        if_trace_simd! {
            println!("{a:x?}, {b:x?}, {c:x?}, {d:x?}");
        }
        // Look every lane up in the u32 table (scale 4 = size of one u32 entry).
        unroll!(let [a, b, c, d] => |x| x.into());
        unroll!(let [a, b, c, d] => |x| simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, x, 4));
        // Narrow the looked-up values back to one byte per lane.
        unroll!(let [a, b, c, d] => |x| Simd::<u32, GATHER_BATCH_SIZE>::from(x).cast::<u8>());
        if_trace_simd! {
            println!("a,b,c,d: {a:x?}, {b:x?}, {c:x?}, {d:x?}");
        }
        // Concatenate the four narrowed chunks into one 256-bit register.
        unroll!(let [a, b, c, d] => |x| util::cast::<_, u64>(x));
        let abcd: simd::arch::__m256i = u64::load_array([a, b, c, d]).into();
        if_trace_simd! {
            // Separate name so the traced copy can never shadow the returned `abcd`.
            let abcd_dbg: Simd<u8, DIGIT_BATCH_SIZE> = abcd.into();
            println!("abcd: {abcd_dbg:x?}");
        }
        abcd
    }};
}
// Combines the digit values produced by `hex_digits_simd_inline!` into bytes: each
// `[msb, lsb]` pair of digit values becomes `(msb << 4) | lsb`, and the result is an `__m128i`.
// Digits carrying `INVALID_BIT` produce garbage bytes, so callers must validate separately.
macro_rules! merge_hex_digits_into_bytes_inline {
    ($hex_digits:ident) => {{
        // First (most significant) and second digit of every pair.
        let msb = simd::extract_lo_bytes($hex_digits);
        let lsb = simd::extract_hi_bytes($hex_digits);
        // x86 has no per-byte shift, so shift each 64-bit lane left by 4 and then mask every byte
        // with 0xF0 to drop any bits that crossed in from the neighbouring byte.
        let msb1: simd::arch::__m128i;
        // `core::arch::asm!` (not `std::arch`) so this builds without `std`, like the other
        // inline assembly in this file.
        unsafe { core::arch::asm!("vpsllq {dst}, {src}, 4", src = in(xmm_reg) msb, dst = lateout(xmm_reg) msb1, options(pure, nomem, preserves_flags, nostack)) };
        if_trace_simd! {
            let msb1: Simd<u8, WIDE_BATCH_SIZE> = msb1.into();
            println!("msb1: {msb1:x?}");
        }
        let msb2 = msb1.and(0xf0u8.splat().into());
        let b = msb2.or(lsb);
        if_trace_simd! {
            // NOTE(review): this block shadows `msb`, `msb1`, `msb2`, `lsb` and `b`; returning the
            // register-typed `b` below relies on `if_trace_simd!` scoping its body — confirm.
            unroll!(let [msb, msb1, msb2, lsb, b] => |x| Simd::<u8, WIDE_BATCH_SIZE>::from(x));
            println!("| Packed | Msb | <<4 | & | Lsb | Bytes | |");
            Simd::<u8, DIGIT_BATCH_SIZE>::from($hex_digits)
                .to_array()
                .chunks(2)
                .zip(msb.to_array())
                .zip(msb1.to_array())
                .zip(msb2.to_array())
                .zip(lsb.to_array())
                .zip(b.to_array())
                .for_each(|(((((chunk, msb), msb1), msb2), lsb), b)| {
                    println!(
                        "| {chunk:02x?} | {msb:x?} | {msb1:x?} | {msb2:x?} | {lsb:x?} | {b:02x?} | {ok} |",
                        chunk = (chunk[0] as u16) << 4 | (chunk[1] as u16),
                        ok = if chunk[0] == msb && chunk[1] == lsb {
                            '✓'
                        } else {
                            '✗'
                        }
                    );
                });
        }
        b
    }};
}
impl HexByteSimdDecoder for HexByteDecoderA {
    /// Decodes one batch of `[msb, lsb]` digit pairs with AVX2 gathers. Returns `None` if any
    /// digit in the batch is not an ASCII hex digit.
    #[inline(always)]
    fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>> {
        let chunks = hi_los.as_ptr() as *const [u8; GATHER_BATCH_SIZE];
        // SAFETY: `chunks` points into the local array `hi_los`, which stays alive for this call;
        // the macro reads four consecutive `[u8; GATHER_BATCH_SIZE]` chunks from it.
        // NOTE(review): relies on `4 * GATHER_BATCH_SIZE == DIGIT_BATCH_SIZE` — confirm.
        let digits = unsafe { hex_digits_simd_inline!(chunks) };
        let any_invalid = digits.test_and_non_zero(INVALID_BIT.splat().into());
        if any_invalid {
            None
        } else {
            Some(merge_hex_digits_into_bytes_inline!(digits).into())
        }
    }
}
/// Short alias for [`HexByteDecoderA`].
pub type HBD = HexByteDecoderA;
/// Conversions between array shapes of the same bytes.
pub mod conv {
    use brisk::util::cast;
    /// Flattens `N_IN` two-byte pairs into one `[u8; N_IN * 2]` array, keeping byte order
    /// (`[[a, b], [c, d]]` becomes `[a, b, c, d]`).
    #[inline(always)]
    pub const fn u8x2_to_u8<const N_IN: usize>(a: [[u8; 2]; N_IN]) -> [u8; N_IN * 2] {
        // SAFETY: `[[u8; 2]; N]` and `[u8; 2 * N]` have the same size and layout (arrays have no
        // padding, and `u8` has no invalid bit patterns). NOTE(review): assumes
        // `brisk::util::cast` is a size-checked, transmute-style reinterpretation — confirm.
        unsafe { cast(a) }
    }
}
// Scalar decoder used by `decode_hex_bytes_unchecked` for the input tail (or the whole input).
//
// Decodes `$ascii[$i..]` two bytes at a time into `$bytes[$i / 2..]` and advances the caller's
// `$i` to `$ascii.len()`. Digit values are OR-ed into an accumulator that is checked once after
// the loop. If any invalid digit was seen, the expansion executes `return false` from the
// *enclosing function*. Not checked here: `$ascii.len() - $i` must be even and
// `$bytes.len() >= $ascii.len() / 2`.
macro_rules! decode_hex_bytes_non_vectored {
    ($i:ident, $ascii:ident, $bytes:ident) => {{
        // Local to the expansion (macro hygiene), so it cannot clash with caller variables.
        let mut seen = 0u8;
        while $i < $ascii.len() {
            // SAFETY: `$i + 1 < $ascii.len()` because the remaining length is even.
            let pair = unsafe { *($ascii.as_ptr().add($i) as *const [u8; 2]) };
            let lo = hex_digit(pair[1]);
            let hi = hex_digit(pair[0]);
            seen |= lo;
            seen |= hi;
            // An invalid `hi` loses its marker bit in this shift; `seen` still records it.
            let byte = (hi << 4) | lo;
            // SAFETY: `$i / 2 < $ascii.len() / 2 <= $bytes.len()`.
            unsafe { *$bytes.get_unchecked_mut($i >> 1) = MaybeUninit::new(byte) };
            $i += 2;
        }
        if seen & INVALID_BIT != 0 {
            return false;
        }
    }};
}
/// Decodes `ascii` (consecutive pairs of ASCII hex digits, most significant digit first) into
/// `bytes`. Returns `true` on success.
///
/// Returns `false` if `ascii.len()` is odd, if `bytes.len() != ascii.len() / 2`, or if any input
/// byte is not an ASCII hex digit. Bytes handled by the vectored loop are only checked while
/// [`VALIDATE`] is `true`. On `true`, every element of `bytes` has been written. On `false`,
/// `bytes` may be partially written and must not be treated as initialized.
///
/// The vectored path uses AVX2 gathers and VEX-encoded inline assembly, so it requires AVX2.
#[inline(always)]
fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bool {
    // these checks should always be eliminated because they are performed more efficiently
    // (sometimes statically) in the callers, but they provide a major safeguard against nasty
    // memory safety issues.
    debug_assert_eq!(
        ascii.len() >> 1 << 1,
        ascii.len(),
        "len of ascii is not a multiple of 2"
    );
    if ascii.len() >> 1 << 1 != ascii.len() {
        return false;
    }
    debug_assert_eq!(
        ascii.len() >> 1,
        bytes.len(),
        "len of ascii is not twice that of bytes"
    );
    if ascii.len() >> 1 != bytes.len() {
        return false;
    }
    const VECTORED: bool = true;
    if VECTORED {
        // Hint the CPU to pull the start of the gather table into L1.
        unsafe { arch::_mm_prefetch(ASCII_DIGITS_SIMD as *const i8, arch::_MM_HINT_T0) };
        // OR of every digit value produced by the vectored loop; INVALID_BIT set => bad input.
        let mut bad: arch::__m256i = u8::splat_zero().into();
        let mut i = 0;
        while i < util::align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) {
            // SAFETY: assuming `util::align_down_to` rounds down to a multiple of
            // `DIGIT_BATCH_SIZE`, `i + DIGIT_BATCH_SIZE <= ascii.len()`, so the whole batch
            // read by the macro is in bounds.
            let hex_digits = unsafe {
                let hi_los = ascii.as_ptr().add(i) as *const [u8; GATHER_BATCH_SIZE];
                hex_digits_simd_inline!(hi_los)
            };
            if VALIDATE {
                unsafe {
                    core::arch::asm!("vpor {bad}, {digits}, {bad}", bad = inout(ymm_reg) bad, digits = in(ymm_reg) hex_digits, options(pure, nomem, preserves_flags, nostack));
                }
            }
            let buf = merge_hex_digits_into_bytes_inline!(hex_digits);
            unsafe {
                // TODO: consider unrolling 2 iterations of this loop and buffering bytes in a single
                // ymm register to be stored at once.
                //
                // `vmovdqu`, not `vmovdqa`: `bytes` is a `[MaybeUninit<u8>]` with alignment 1
                // (stack arrays, `Box::new_uninit_slice`), so the destination is not guaranteed
                // to be 16-byte aligned, and an aligned store to a misaligned address faults.
                // The 16 bytes written end at `(i + DIGIT_BATCH_SIZE) / 2 <= bytes.len()`.
                core::arch::asm!("vmovdqu [{}], {}", in(reg) bytes.as_mut_ptr().add(i >> 1) as *mut i8, in(xmm_reg) buf, options(preserves_flags, nostack));
            };
            i += DIGIT_BATCH_SIZE;
        }
        decode_hex_bytes_non_vectored!(i, ascii, bytes);
        !bad.test_and_non_zero(INVALID_BIT.splat().into())
    } else {
        let mut i = 0;
        decode_hex_bytes_non_vectored!(i, ascii, bytes);
        true
    }
}
/// This function is a safe bet when you need to decode hex in a const context, on a system that
/// does not support AVX2, or just don't feel comfortable relying on so much unsafe code and inline
/// ASM.
///
/// It performs only 8% worse than the SIMD-accelerated implementation.
#[inline]
pub const fn hex_bytes_sized_const<const N: usize>(ascii: &[u8; N * 2]) -> Option<[u8; N]> {
if N == 0 {
Some([0u8; N])
} else {
let mut bytes = MaybeUninit::uninit_array();
let mut i = 0;
while i < N * 2 {
// Ensure bounds checks are removed. Might not be necessary.
if i >> 1 >= bytes.len() {
unsafe { core::hint::unreachable_unchecked() };
}
match hex_byte(unsafe { *ascii.get_unchecked(i) }, unsafe {
*ascii.get_unchecked(i + 1)
}) {
Some(b) => bytes[i >> 1] = MaybeUninit::new(b),
None => return None,
}
i += 2;
}
Some(unsafe { MaybeUninit::array_assume_init(bytes) })
}
}
/// Decodes `N * 2` ASCII hex digits into `N` bytes using the SIMD-accelerated decoder.
/// Returns `None` if any input byte is not an ASCII hex digit.
#[inline]
pub fn hex_bytes_sized<const N: usize>(ascii: &[u8; N * 2]) -> Option<[u8; N]> {
    if N == 0 {
        return Some([0u8; N]);
    }
    let mut out: [MaybeUninit<u8>; N] = MaybeUninit::uninit_array();
    if !decode_hex_bytes_unchecked(ascii, &mut out) {
        return None;
    }
    // SAFETY: on success `decode_hex_bytes_unchecked` wrote every element of `out`.
    Some(unsafe { MaybeUninit::array_assume_init(out) })
}
/// Like `hex_bytes_sized`, but decodes directly into a heap allocation, so a large `N` never
/// lives on the stack. Returns `None` if any input byte is not an ASCII hex digit.
#[cfg(feature = "alloc")]
#[inline]
pub fn hex_bytes_sized_heap<const N: usize>(ascii: &[u8; N * 2]) -> Option<Box<[u8; N]>> {
    if N == 0 {
        return Some(Box::new([0u8; N]));
    }
    // SAFETY: an array of `MaybeUninit<u8>` has no validity requirements, so viewing
    // uninitialized memory as `[MaybeUninit<u8>; N]` is sound.
    let mut out: Box<[MaybeUninit<u8>; N]> =
        unsafe { Box::<[MaybeUninit<u8>; N]>::new_uninit().assume_init() };
    if !decode_hex_bytes_unchecked(ascii, &mut out[..]) {
        return None;
    }
    let raw = Box::into_raw(out).cast::<[u8; N]>();
    // SAFETY: decoding succeeded, so every byte is initialized; `[MaybeUninit<u8>; N]` and
    // `[u8; N]` have identical size and alignment, so the allocation can be reused as-is.
    Some(unsafe { Box::from_raw(raw) })
}
/// Decodes a hex string of any even length into a boxed byte slice using the SIMD-accelerated
/// decoder. Returns `None` for odd-length input or any non-hex byte.
#[cfg(feature = "alloc")]
#[inline]
pub fn hex_bytes_dyn_unsafe(ascii: &[u8]) -> Option<Box<[u8]>> {
    if ascii.len() % 2 != 0 {
        return None;
    }
    let mut out = Box::<[u8]>::new_uninit_slice(ascii.len() / 2);
    if !decode_hex_bytes_unchecked(ascii, &mut out) {
        return None;
    }
    // SAFETY: on success `decode_hex_bytes_unchecked` wrote every element of `out`.
    Some(unsafe { Box::<[_]>::assume_init(out) })
}
/// Scalar decoder writing into an uninitialized boxed slice with unchecked stores.
/// Returns `None` for odd-length input or any non-hex byte.
#[cfg(feature = "alloc")]
#[inline]
pub fn hex_bytes_dyn_unsafe_iter(ascii: &[u8]) -> Option<Box<[u8]>> {
    if ascii.len() % 2 != 0 {
        return None;
    }
    let mut out = Box::<[u8]>::new_uninit_slice(ascii.len() / 2);
    let mut idx = 0usize;
    for &[hi_char, lo_char] in ascii.array_chunks::<2>() {
        let lo = hex_digit(lo_char);
        let hi = hex_digit(hi_char);
        if (lo & INVALID_BIT) | (hi & INVALID_BIT) != 0 {
            return None;
        }
        // SAFETY: `array_chunks` yields exactly `ascii.len() / 2 == out.len()` pairs.
        unsafe { *out.get_unchecked_mut(idx) = MaybeUninit::new((hi << 4) | lo) };
        idx += 1;
    }
    // SAFETY: the loop completed without returning early, so every element was written.
    Some(unsafe { Box::<[_]>::assume_init(out) })
}
/// Fully safe scalar decoder built on [`hex_byte`].
/// Returns `None` for odd-length input or any non-hex byte.
#[cfg(feature = "alloc")]
#[inline]
pub fn hex_bytes_dyn(ascii: &[u8]) -> Option<Box<[u8]>> {
    let pairs = ascii.array_chunks::<2>();
    if !pairs.remainder().is_empty() {
        return None;
    }
    let decoded: Vec<u8> = pairs.map(|&[hi, lo]| hex_byte(hi, lo)).collect::<Option<_>>()?;
    Some(decoded.into_boxed_slice())
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::test::*;

    #[test]
    fn test_hex_digit() {
        const HEX_DIGITS_LOWER: &[char; 16] = &[
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
        ];
        const HEX_DIGITS_UPPER: &[char; 16] = &[
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F',
        ];
        for (i, digit) in HEX_DIGITS_LOWER.into_iter().enumerate() {
            assert_eq!(hex_digit(*digit as u8), i as u8);
        }
        for (i, digit) in HEX_DIGITS_UPPER.into_iter().enumerate() {
            assert_eq!(hex_digit(*digit as u8), i as u8);
        }
        // Bytes just outside each valid range, plus non-ASCII bytes, must be rejected.
        for b in [b'/', b':', b'@', b'G', b'`', b'g', b' ', 0x00, 0x80, 0xff] {
            assert_eq!(hex_digit(b), INVALID_BIT, "byte {b:#04x} should be invalid");
        }
    }

    #[test]
    fn test_hex_digit_simd() {
        const HEX_DIGITS: &[char; DIGIT_BATCH_SIZE] = &[
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0',
            '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F',
        ];
        let mut set8 = [0u8; DIGIT_BATCH_SIZE];
        for (c, b) in HEX_DIGITS.iter().zip(set8.iter_mut()) {
            *b = *c as u8;
        }
        let mut sete = [0u8; DIGIT_BATCH_SIZE];
        for i in 0..(DIGIT_BATCH_SIZE) {
            sete[i] = (i as u8) % 16;
        }
        assert_eq!(
            hex_digit_simd::<DIGIT_BATCH_SIZE>(Simd::from_array(set8)),
            Simd::from_array(sete)
        );
    }

    #[test]
    fn test_hex_byte() {
        const HEX_BYTES_VALID: &[([u8; 2], u8)] = &[
            (['f' as u8, 'f' as u8], 0xff),
            (['0' as u8, '0' as u8], 0x00),
            (['1' as u8, '1' as u8], 0x11),
            (['e' as u8, 'f' as u8], 0xef),
            (['f' as u8, 'e' as u8], 0xfe),
            (['0' as u8, 'f' as u8], 0x0f),
            (['f' as u8, '0' as u8], 0xf0),
            (['A' as u8, 'b' as u8], 0xab),
        ];
        for (hb, b) in HEX_BYTES_VALID {
            assert_eq!(hex_byte(hb[0], hb[1]), Some(*b));
            assert_eq!(HexByteDecoderA::decode_unpacked(hb[0], hb[1]), *b as u16);
            assert_eq!(HexByteDecoderA::decode_packed(hb), *b as u16);
        }
        // Cover an invalid low digit, an invalid high digit, both invalid, range boundaries and
        // non-ASCII bytes (the previous table only ever exercised an invalid low digit).
        const HEX_BYTES_INVALID: &[[u8; 2]] = &[
            *b"fg",
            *b"0g",
            *b"g0",
            *b"Gf",
            *b"gg",
            *b"/:",
            *b"@`",
            [0xff, b'0'],
            [b'0', 0x80],
        ];
        for hb in HEX_BYTES_INVALID {
            assert_eq!(hex_byte(hb[0], hb[1]), None);
            let unpacked = HexByteDecoderA::decode_unpacked(hb[0], hb[1]);
            assert_ne!(unpacked & WIDE_INVALID_BIT, 0);
            assert!(unpacked > u8::MAX as u16);
            assert_ne!(HexByteDecoderA::decode_packed(hb) & WIDE_INVALID_BIT, 0);
        }
    }

    #[test]
    fn test_hex_byte_simd() {
        const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [
            *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", *b"ff", *b"00", *b"11",
            *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
        ];
        const BYTES_VALID: [u8; WIDE_BATCH_SIZE] = [
            0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f,
            0xf0, 0x34,
        ];
        let hex_bytes = conv::u8x2_to_u8(HEX_BYTES_VALID);
        let bytes = Simd::from_array(BYTES_VALID);
        if_trace_simd! {
            println!("hex_bytes: {HEX_BYTES_VALID:02x?}");
            println!("hex_bytes: {hex_bytes:02x?}");
            println!("bytes: {BYTES_VALID:02x?}");
            println!("bytes: {bytes:04x?}");
        }
        assert_eq!(HexByteDecoderA::decode_simd(hex_bytes), Some(bytes));

        // A single bad byte anywhere in the batch (high digit, low digit, non-ASCII) must reject
        // the whole batch.
        for (pos, bad) in [(0, b'g'), (7, b'G'), (16, b' '), (DIGIT_BATCH_SIZE - 1, 0xff)] {
            let mut corrupted = hex_bytes;
            corrupted[pos] = bad;
            assert_eq!(
                HexByteDecoderA::decode_simd(corrupted),
                None,
                "batch corrupted at {pos} should be rejected"
            );
        }
    }

    macro_rules! test_f {
        (boxed $f:ident) => {
            test_f!(@ $f, Box::as_ref)
        };
        (@ $f:ident, $trans:expr) => {
            for (i, Sample { bytes, hex_bytes }) in SAMPLES.into_iter().enumerate() {
                let result = $f(hex_bytes.as_bytes());
                assert_eq!(
                    result.as_ref().map($trans),
                    Some(bytes.as_bytes()),
                    "Sample {i} ({hex_bytes:?} => {bytes:?}) did not decode correctly (expected Some)"
                );
            }
            for (i, hex_bytes) in INVALID_SAMPLES.into_iter().enumerate() {
                let result = $f(hex_bytes.as_bytes());
                assert_eq!(
                    result.as_ref().map($trans),
                    None,
                    "Sample {i} ({hex_bytes:?}) did not decode correctly (expected None)"
                );
            }
        };
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_dyn_iter_option() {
        test_f!(boxed hex_bytes_dyn);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_dyn_unsafe() {
        test_f!(boxed hex_bytes_dyn_unsafe);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn test_dyn_unsafe_iter() {
        test_f!(boxed hex_bytes_dyn_unsafe_iter);
    }
}