Compare commits

...

5 Commits

10 changed files with 1872 additions and 922 deletions

View File

@ -7,6 +7,7 @@ edition = "2021"
default = ["std"]
alloc = []
std = ["alloc"]
test = []
[dependencies]
@ -15,5 +16,11 @@ criterion = { version = "0.4", features = [ "real_blackbox" ] }
rand = "0.8.5"
[[bench]]
name = "bench"
name = "dec"
harness = false
required-features = ["test"]
[[bench]]
name = "enc"
harness = false
required-features = ["test"]

View File

@ -5,7 +5,9 @@
use std::mem::MaybeUninit;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use fast_hex::*;
use fast_hex::dec::*;
use fast_hex::simd::{DIGIT_BATCH_SIZE, WIDE_BATCH_SIZE};
use fast_hex::test::name;
const ASCII_BYTES: &[u8; 16] = b"Donald J. Trump!";
const HEX_BYTES: &[u8; ASCII_BYTES.len() * 2] = b"446F6E616C64204A2E205472756D7021";
@ -133,22 +135,6 @@ pub fn bench_256(c: &mut Criterion) {
benchmark_sized::<{ ASCII_BYTES_LONG.len() }, false>("256", HEX_BYTES_LONG, c);
}
const fn __make_hex_chars() -> [u8; 16] {
let mut chars = [0u8; 16];
let mut i = 0u8;
while (i as usize) < chars.len() {
chars[i as usize] = if i < 10 {
'0' as u8 + i
} else {
'a' as u8 + i - 10
};
i += 1;
}
chars
}
const HEX_CHARS: [u8; 16] = __make_hex_chars();
trait SliceRandom {
type Item;
@ -227,7 +213,7 @@ pub fn bench_2k(c: &mut Criterion) {
unsafe { std::mem::MaybeUninit::uninit().assume_init() };
let mut rng = rand::thread_rng();
for b in hex_bytes.iter_mut() {
*b = MaybeUninit::new(*HEX_CHARS.choose(&mut rng).unwrap());
*b = MaybeUninit::new(*fast_hex::enc::HEX_CHARS_LOWER.choose(&mut rng).unwrap());
}
let hex_bytes: [u8; LEN2] = unsafe { std::mem::transmute(hex_bytes) };
let bytes = match hex_bytes_dyn(hex_bytes.as_ref()) {
@ -250,7 +236,7 @@ pub fn bench_512k(c: &mut Criterion) {
unsafe { std::mem::transmute(Box::<[u8; LEN2]>::new_uninit()) };
let mut rng = rand::thread_rng();
for b in hex_bytes.iter_mut() {
*b = MaybeUninit::new(*HEX_CHARS.choose(&mut rng).unwrap());
*b = MaybeUninit::new(*fast_hex::enc::HEX_CHARS_LOWER.choose(&mut rng).unwrap());
}
let hex_bytes: Box<[u8; LEN2]> = unsafe { std::mem::transmute(hex_bytes) };
let bytes = match hex_bytes_dyn(hex_bytes.as_ref()) {
@ -273,7 +259,7 @@ pub fn bench_1_6m(c: &mut Criterion) {
unsafe { std::mem::transmute(Box::<[u8; LEN2]>::new_uninit()) };
let mut rng = rand::thread_rng();
for b in hex_bytes.iter_mut() {
*b = MaybeUninit::new(*HEX_CHARS.choose(&mut rng).unwrap());
*b = MaybeUninit::new(*fast_hex::enc::HEX_CHARS_LOWER.choose(&mut rng).unwrap());
}
let hex_bytes: Box<[u8; LEN2]> = unsafe { std::mem::transmute(hex_bytes) };
let bytes = match hex_bytes_dyn(hex_bytes.as_ref()) {
@ -359,7 +345,6 @@ pub fn bench_micro_hex_byte(c: &mut Criterion) {
});
bench_decoder::<HexByteDecoderA>(c, stringify!(HexByteDecoderA));
bench_decoder::<HexByteDecoderB>(c, stringify!(HexByteDecoderB));
}
pub fn bench_nano_hex_digit(c: &mut Criterion) {
@ -399,7 +384,12 @@ pub fn bench_nano_hex_byte(c: &mut Criterion) {
}
bench_decoder::<HexByteDecoderA>(c, stringify!(HexByteDecoderA));
bench_decoder::<HexByteDecoderB>(c, stringify!(HexByteDecoderB));
}
fn verification() {
fast_hex::simd::if_trace_simd! {
panic!("Illegal benchmark state: SIMD tracing enabled");
}
}
criterion_group!(
@ -412,4 +402,4 @@ criterion_group!(
);
criterion_group!(micro_benches, bench_micro_hex_digit, bench_micro_hex_byte);
criterion_group!(nano_benches, bench_nano_hex_digit, bench_nano_hex_byte);
criterion_main!(decode_benches, micro_benches, nano_benches);
criterion_main!(verification, decode_benches, micro_benches, nano_benches);

61
benches/enc.rs Normal file
View File

@ -0,0 +1,61 @@
#![feature(generic_const_exprs)]
#![feature(new_uninit)]
#![feature(portable_simd)]
use std::mem::MaybeUninit;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use fast_hex::enc::{Encode as _, Encoder};
use fast_hex::test::name;
type Enc = Encoder<true>;
const ASCII_BYTES: &[u8; 16] = b"Donald J. Trump!";
const HEX_BYTES: &[u8; ASCII_BYTES.len() * 2] = b"446F6E616C64204A2E205472756D7021";
const ASCII_BYTES_LONG: &[u8; 256] = b"Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!";
const HEX_BYTES_LONG: &[u8; ASCII_BYTES_LONG.len() * 2] = b"446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021";
fn benchmark_sized<const N: usize, const HEAP_ONLY: bool>(
name: &str,
bytes: &[u8; N],
c: &mut Criterion,
) where
[(); N * 2]:,
{
if !HEAP_ONLY {
c.bench_function(name!(name, "enc const"), |b| {
b.iter(|| Enc::enc_const(black_box(bytes)))
});
c.bench_function(name!(name, "enc sized"), |b| {
b.iter(|| Enc::enc_sized(black_box(bytes)))
});
}
c.bench_function(name!(name, "enc sized heap"), |b| {
b.iter(|| Enc::enc_sized_heap(black_box(bytes)))
});
benchmark(name, bytes, c);
}
fn benchmark(name: &str, bytes: &[u8], c: &mut Criterion) {
c.bench_function(name!(name, "enc slice"), |b| {
b.iter(|| Enc::enc_slice(black_box(bytes)))
});
}
pub fn bench_16(c: &mut Criterion) {
benchmark_sized::<{ ASCII_BYTES.len() }, false>("16", ASCII_BYTES, c);
}
pub fn bench_256(c: &mut Criterion) {
benchmark_sized::<{ ASCII_BYTES_LONG.len() }, false>("256", ASCII_BYTES_LONG, c);
}
fn verification() {
fast_hex::simd::if_trace_simd! {
panic!("Illegal benchmark state: SIMD tracing enabled");
}
}
criterion_group!(encode_benches, bench_16, bench_256,);
criterion_main!(verification, encode_benches);

631
src/dec.rs Normal file
View File

@ -0,0 +1,631 @@
//! SIMD-accelerated, validating hex decoding.
use core::mem::MaybeUninit;
use core::simd::*;
#[cfg(feature = "alloc")]
use alloc::{boxed::Box, vec::Vec};
use crate::prelude::*;
use simd::{DIGIT_BATCH_SIZE, GATHER_BATCH_SIZE, SIMD_WIDTH, WIDE_BATCH_SIZE};
const VALIDATE: bool = true;
pub const INVALID_BIT: u8 = 0b1000_0000;
pub const WIDE_INVALID_BIT: u16 = 0b1000_1000_0000_0000;
const ASCII_DIGITS: [u8; 256] = {
array_op!(
gen[256] | i | {
const DIGIT_MIN: u8 = '0' as u8;
const DIGIT_MAX: u8 = '9' as u8;
const LOWER_MIN: u8 = 'a' as u8;
const LOWER_MAX: u8 = 'f' as u8;
const UPPER_MIN: u8 = 'A' as u8;
const UPPER_MAX: u8 = 'F' as u8;
let i = i as u8;
match i {
DIGIT_MIN..=DIGIT_MAX => i - DIGIT_MIN,
LOWER_MIN..=LOWER_MAX => 10 + i - LOWER_MIN,
UPPER_MIN..=UPPER_MAX => 10 + i - UPPER_MIN,
_ => INVALID_BIT,
}
}
)
};
const __ASCII_DIGITS_SIMD: [u32; 256] = util::cast_u8_u32(ASCII_DIGITS);
const ASCII_DIGITS_SIMD: *const i32 = &__ASCII_DIGITS_SIMD as *const u32 as *const i32;
/// Returns [`INVALID_BIT`] if invalid. Based on `char.to_digit()` in the stdlib.
#[inline]
pub const fn hex_digit(ascii: u8) -> u8 {
ASCII_DIGITS[ascii as usize]
}
#[inline(always)]
pub fn hex_digit_simd<const LANES: usize>(ascii: Simd<u8, LANES>) -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount,
{
unsafe {
Simd::gather_select_unchecked(
&ASCII_DIGITS,
Mask::splat(true),
ascii.cast(),
u8::splat_zero(),
)
}
}
/// Parses an ascii hex byte.
#[inline(always)]
pub const fn hex_byte(msb: u8, lsb: u8) -> Option<u8> {
let msb = hex_digit(msb);
let lsb = hex_digit(lsb);
// second is faster (perhaps it pipelines better?)
//if (msb | lsb) & INVALID_BIT != 0 {
if (msb & INVALID_BIT) | (lsb & INVALID_BIT) != 0 {
return None;
}
Some(msb << 4 | lsb)
}
/// A decoder for a single hex byte.
#[const_trait]
pub trait HexByteDecoder {
/// Parses an ascii hex byte. Any return value exceeding [`u8::MAX`] indicates invalid input.
fn decode_unpacked(hi: u8, lo: u8) -> u16;
/// Parses an ascii hex byte. Any return value exceeding [`u8::MAX`] indicates invalid input.
#[inline(always)]
fn decode_packed([hi, lo]: &[u8; 2]) -> u16 {
Self::decode_unpacked(*hi, *lo)
}
}
/// A decoder for a sized batch of hex bytes.
pub trait HexByteSimdDecoder {
/// Parses an ascii hex byte. Any element of the return value exceeding [`u8::MAX`] indicates invalid input.
fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>>;
}
pub struct HexByteDecoderA;
impl const HexByteDecoder for HexByteDecoderA {
// util::defer_impl! {
// => HexByteDecoderA;
//
// fn decode_unpacked(hi: u8, lo: u8) -> u16;
//
// fn decode_packed(hi_lo: &[u8; 2]) -> u16;
// }
#[inline(always)]
fn decode_unpacked(hi: u8, lo: u8) -> u16 {
let lo = hex_digit(lo) as u16;
let hi = hex_digit(hi) as u16;
// kind of bizarre: changing the order of these decreases perf by 6-12%
(hi << 4) | lo | ((lo & INVALID_BIT as u16) << 8)
}
#[inline(always)]
fn decode_packed([hi, lo]: &[u8; 2]) -> u16 {
let lo = hex_digit(*lo) as u16;
let hi = hex_digit(*hi) as u16;
(hi << 4) | lo | ((lo & INVALID_BIT as u16) << 8)
}
}
macro_rules! hex_digits_simd_inline {
($ptr:ident) => {{
if_trace_simd! {
println!("hi_los: {:x?}", *$ptr);
}
unroll!(let [a: (0), b: (1), c: (2), d: (3)] => |i| *$ptr.add(i));
if_trace_simd! {
let f = |x| __ASCII_DIGITS_SIMD[x as usize];
println!(
"{:x?}, {:x?}, {:x?}, {:x?}",
a.map(f),
b.map(f),
c.map(f),
d.map(f)
);
}
unroll!(let [a, b, c, d] => |x| Simd::from_array(x));
unroll!(let [a, b, c, d] => |x| x.cast::<u32>());
if_trace_simd! {
println!("{a:x?}, {b:x?}, {c:x?}, {d:x?}");
}
unroll!(let [a, b, c, d] => |x| x.into());
unroll!(let [a, b, c, d] => |x| simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, x, 4));
unroll!(let [a, b, c, d] => |x| Simd::<u32, GATHER_BATCH_SIZE>::from(x).cast::<u8>());
if_trace_simd! {
println!("a,b,c,d: {a:x?}, {b:x?}, {c:x?}, {d:x?}");
}
// load the 64-bit integers into registers
unroll!(let [a, b, c, d] => |x| util::cast::<_, u64>(x).load_128());
if_trace_simd! {
unroll!(let [a, b, c, d] => |x| Simd::<u8, 16>::from(x));
println!("a,b,c,d: {a:x?}, {b:x?}, {c:x?}, {d:x?}");
}
// copy the second 64-bit integer into the upper half of xmm0 (lower half is the first 64-bit integer)
let ab = simd::merge_lo_hi_m128(a, b);
// copy the fourth 64-bit integer into the upper half of xmm2 (lower half is the third 64-bit integer)
let cd = simd::merge_lo_hi_m128(c, d);
if_trace_simd! {
unroll!(let [ab, cd] => |x| Simd::<u8, 16>::from(x));
println!("ab,cd: {ab:x?}, {cd:x?}");
}
// merge the xmm0 and xmm1 (ymm1) registers into ymm0
let abcd = simd::merge_m128_m256(ab, cd);
if_trace_simd! {
let abcd: Simd<u8, DIGIT_BATCH_SIZE> = abcd.into();
println!("abcd: {abcd:x?}");
}
abcd
}};
}
macro_rules! merge_hex_digits_into_bytes_inline {
($hex_digits:ident) => {{
let msb = simd::extract_lo_bytes($hex_digits);
let lsb = simd::extract_hi_bytes($hex_digits);
let msb1: simd::arch::__m128i;
unsafe { std::arch::asm!("vpsllq {dst}, {src}, 4", src = in(xmm_reg) msb, dst = lateout(xmm_reg) msb1) };
if_trace_simd! {
let msb1: Simd<u8, WIDE_BATCH_SIZE> = msb1.into();
println!("msb1: {msb1:x?}");
}
let msb2 = msb1.and(0xf0u8.splat().into());
let b = msb2.or(lsb);
if_trace_simd! {
unroll!(let [msb, msb1, msb2, lsb, b] => |x| Simd::<u8, WIDE_BATCH_SIZE>::from(x));
println!("| Packed | Msb | <<4 | & | Lsb | Bytes | |");
Simd::<u8, DIGIT_BATCH_SIZE>::from($hex_digits)
.to_array()
.chunks(2)
.zip(msb.to_array())
.zip(msb1.to_array())
.zip(msb2.to_array())
.zip(lsb.to_array())
.zip(b.to_array())
.for_each(|(((((chunk, msb), msb1), msb2), lsb), b)| {
println!(
"| {chunk:02x?} | {msb:x?} | {msb1:x?} | {msb2:x?} | {lsb:x?} | {b:02x?} | {ok} |",
chunk = (chunk[0] as u16) << 4 | (chunk[1] as u16),
ok = if chunk[0] == msb && chunk[1] == lsb {
'✓'
} else {
'✗'
}
);
});
}
b
}};
}
impl HexByteSimdDecoder for HexByteDecoderA {
// util::defer_impl! {
// => HexByteDecoderA;
//
// fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>>;
// }
#[inline(always)]
fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>> {
let hi_los = hi_los.as_ptr() as *const [u8; GATHER_BATCH_SIZE];
let hex_digits = unsafe { hex_digits_simd_inline!(hi_los) };
if hex_digits.test_and_non_zero(INVALID_BIT.splat().into()) {
return None;
}
Some(merge_hex_digits_into_bytes_inline!(hex_digits).into())
}
}
pub type HBD = HexByteDecoderA;
pub mod conv {
use crate::util;
#[inline(always)]
pub const fn u8x2_to_u8<const N_IN: usize>(a: [[u8; 2]; N_IN]) -> [u8; N_IN * 2] {
unsafe { util::cast(a) }
}
}
macro_rules! decode_hex_bytes_non_vectored {
($i:ident, $ascii:ident, $bytes:ident) => {{
//let mut bad = 0u16;
let mut bad = 0u8;
while $i < $ascii.len() {
/*let b = HBD::decode_packed(unsafe { &*($ascii.as_ptr().add($i) as *const [u8; 2]) });
bad |= b;
unsafe { *$bytes.get_unchecked_mut($i >> 1) = MaybeUninit::new(b as u8) };*/
let [hi, lo] = unsafe { *($ascii.as_ptr().add($i) as *const [u8; 2]) };
unroll!(let [lo, hi] => |x| hex_digit(x));
unroll!([(lo), (hi)] => |x| bad |= x);
/*if (hi & INVALID_BIT) | (lo & INVALID_BIT) != 0 {
println!("bad hex byte at {} ({}{})", $i, $ascii[$i] as char, $ascii[$i + 1] as char);
}*/
let b = (hi << 4) | lo;
unsafe { *$bytes.get_unchecked_mut($i >> 1) = MaybeUninit::new(b) };
$i += 2;
}
//if (bad & WIDE_INVALID_BIT) != 0 {
if (bad & INVALID_BIT) != 0 {
return false;
}
}};
}
#[inline(always)]
fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bool {
// these checks should always be eliminated because they are performed more efficiently
// (sometimes statically) in the callers, but they provide a major safeguard against nasty
// memory safety issues.
debug_assert_eq!(
ascii.len() >> 1 << 1,
ascii.len(),
"len of ascii is not a multiple of 2"
);
if ascii.len() >> 1 << 1 != ascii.len() {
return false;
}
debug_assert_eq!(
ascii.len() >> 1,
bytes.len(),
"len of ascii is not twice that of bytes"
);
if ascii.len() >> 1 != bytes.len() {
return false;
}
const VECTORED: bool = true;
if VECTORED {
let mut bad: arch::__m256i = u8::splat_zero().into();
let mut i = 0;
while i < util::align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) {
let hex_digits = unsafe {
let hi_los = ascii.as_ptr().add(i) as *const [u8; GATHER_BATCH_SIZE];
hex_digits_simd_inline!(hi_los)
};
if VALIDATE {
unsafe {
core::arch::asm!("vpor {bad}, {digits}, {bad}", bad = inout(ymm_reg) bad, digits = in(ymm_reg) hex_digits, options(pure, nomem, preserves_flags, nostack));
}
}
let buf = merge_hex_digits_into_bytes_inline!(hex_digits);
unsafe {
// TODO: consider unrolling 2 iterations of this loop and buffering bytes in a single
// ymm register to be stored at once.
core::arch::asm!("vmovdqa [{}], {}", in(reg) bytes.as_mut_ptr().add(i >> 1) as *mut i8, in(xmm_reg) buf, options(preserves_flags, nostack));
};
i += DIGIT_BATCH_SIZE;
}
decode_hex_bytes_non_vectored!(i, ascii, bytes);
!bad.test_and_non_zero(INVALID_BIT.splat().into())
} else {
let mut i = 0;
decode_hex_bytes_non_vectored!(i, ascii, bytes);
true
}
}
/// This function is a safe bet when you need to decode hex in a const context, on a system that
/// does not support AVX2, or just don't feel comfortable relying on so much unsafe code and inline
/// ASM.
///
/// It performs only 8% worse than the SIMD-accelerated implementation.
#[inline]
pub const fn hex_bytes_sized_const<const N: usize>(ascii: &[u8; N * 2]) -> Option<[u8; N]> {
if N == 0 {
Some([0u8; N])
} else {
let mut bytes = MaybeUninit::uninit_array();
let mut i = 0;
while i < N * 2 {
// Ensure bounds checks are removed. Might not be necessary.
if i >> 1 >= bytes.len() {
unsafe { core::hint::unreachable_unchecked() };
}
match hex_byte(unsafe { *ascii.get_unchecked(i) }, unsafe {
*ascii.get_unchecked(i + 1)
}) {
Some(b) => bytes[i >> 1] = MaybeUninit::new(b),
None => return None,
}
i += 2;
}
Some(unsafe { MaybeUninit::array_assume_init(bytes) })
}
}
#[inline]
pub fn hex_bytes_sized<const N: usize>(ascii: &[u8; N * 2]) -> Option<[u8; N]> {
if N == 0 {
Some([0u8; N])
} else {
let mut bytes = MaybeUninit::uninit_array();
if decode_hex_bytes_unchecked(ascii, &mut bytes) {
Some(unsafe { MaybeUninit::array_assume_init(bytes) })
} else {
None
}
}
}
#[cfg(feature = "alloc")]
#[inline]
pub fn hex_bytes_sized_heap<const N: usize>(ascii: &[u8; N * 2]) -> Option<Box<[u8; N]>> {
if N == 0 {
Some(Box::new([0u8; N]))
} else {
let mut bytes = unsafe { Box::<[_; N]>::new_uninit().assume_init() };
if decode_hex_bytes_unchecked(ascii, bytes.as_mut()) {
Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8; N]) })
} else {
None
}
}
}
#[cfg(feature = "alloc")]
#[inline]
pub fn hex_bytes_dyn_unsafe(ascii: &[u8]) -> Option<Box<[u8]>> {
let len = ascii.len() >> 1;
if len << 1 != ascii.len() {
return None;
}
let mut bytes = Box::new_uninit_slice(len);
if decode_hex_bytes_unchecked(ascii, bytes.as_mut()) {
Some(unsafe { Box::<[_]>::assume_init(bytes) })
} else {
None
}
}
#[cfg(feature = "alloc")]
#[inline]
pub fn hex_bytes_dyn_unsafe_iter(ascii: &[u8]) -> Option<Box<[u8]>> {
let len = ascii.len() >> 1;
if len << 1 != ascii.len() {
return None;
}
let mut bytes = Box::<[u8]>::new_uninit_slice(len);
for (i, [hi, lo]) in ascii.array_chunks::<2>().enumerate() {
let lo = hex_digit(*lo);
let hi = hex_digit(*hi);
if (lo & INVALID_BIT) | (hi & INVALID_BIT) != 0 {
return None;
}
let b = (hi << 4) | lo;
unsafe { *bytes.get_unchecked_mut(i) = MaybeUninit::new(b) };
}
Some(unsafe { Box::<[_]>::assume_init(bytes) })
}
#[cfg(feature = "alloc")]
#[inline]
pub fn hex_bytes_dyn(ascii: &[u8]) -> Option<Box<[u8]>> {
let iter = ascii.array_chunks::<2>();
if iter.remainder().len() != 0 {
return None;
}
iter.map(|[msb, lsb]| hex_byte(*msb, *lsb))
.collect::<Option<Vec<u8>>>()
.map(|v| v.into_boxed_slice())
}
#[cfg(test)]
mod test {
use super::*;
use crate::test::*;
#[test]
fn test_hex_digit() {
const HEX_DIGITS_LOWER: &[char; 16] = &[
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
];
const HEX_DIGITS_UPPER: &[char; 16] = &[
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F',
];
for (i, digit) in HEX_DIGITS_LOWER.into_iter().enumerate() {
assert_eq!(hex_digit(*digit as u8), i as u8);
}
for (i, digit) in HEX_DIGITS_UPPER.into_iter().enumerate() {
assert_eq!(hex_digit(*digit as u8), i as u8);
}
}
#[test]
fn test_hex_digit_simd() {
const HEX_DIGITS: &[char; DIGIT_BATCH_SIZE] = &[
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0',
'1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E',
'F',
// '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
];
let mut set8 = [0u8; DIGIT_BATCH_SIZE];
for (c, b) in HEX_DIGITS.iter().zip(set8.iter_mut()) {
*b = *c as u8;
}
let mut sete = [0u8; DIGIT_BATCH_SIZE];
for i in 0..(DIGIT_BATCH_SIZE) {
sete[i] = (i as u8) % 16;
}
assert_eq!(
hex_digit_simd::<DIGIT_BATCH_SIZE>(Simd::from_array(set8)),
Simd::from_array(sete)
);
}
#[test]
fn test_hex_byte() {
const HEX_BYTES_VALID: &[([u8; 2], u8)] = &[
(['f' as u8, 'f' as u8], 0xff),
(['0' as u8, '0' as u8], 0x00),
(['1' as u8, '1' as u8], 0x11),
(['e' as u8, 'f' as u8], 0xef),
(['f' as u8, 'e' as u8], 0xfe),
(['0' as u8, 'f' as u8], 0x0f),
(['f' as u8, '0' as u8], 0xf0),
];
for (hb, b) in HEX_BYTES_VALID {
assert_eq!(hex_byte(hb[0], hb[1]), Some(*b));
assert_eq!(HexByteDecoderA::decode_unpacked(hb[0], hb[1]), *b as u16);
assert_eq!(HexByteDecoderA::decode_packed(hb), *b as u16);
}
const HEX_BYTES_INVALID: &[[u8; 2]] = &[
['f' as u8, 'g' as u8],
['0' as u8, 'g' as u8],
['1' as u8, 'g' as u8],
['e' as u8, 'g' as u8],
['f' as u8, 'g' as u8],
['0' as u8, 'g' as u8],
['f' as u8, 'g' as u8],
];
for hb in HEX_BYTES_INVALID {
assert_eq!(hex_byte(hb[0], hb[1]), None);
assert_ne!(
HexByteDecoderA::decode_unpacked(hb[0], hb[1]) & WIDE_INVALID_BIT,
0
);
assert_ne!(HexByteDecoderA::decode_packed(hb) & WIDE_INVALID_BIT, 0);
}
}
#[test]
fn test_hex_byte_simd() {
const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", *b"ff", *b"00", *b"11",
*b"ef", *b"fe", *b"0f", *b"f0",
*b"34",
// *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
// *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
];
const BYTES_VALID: [u8; WIDE_BATCH_SIZE] = [
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f,
0xf0,
0x34,
// 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
// 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
];
let hex_bytes = conv::u8x2_to_u8(HEX_BYTES_VALID);
let bytes = Simd::from_array(BYTES_VALID);
if_trace_simd! {
println!("hex_bytes: {HEX_BYTES_VALID:02x?}");
println!("hex_bytes: {hex_bytes:02x?}");
println!("bytes: {BYTES_VALID:02x?}");
println!("bytes: {bytes:04x?}");
}
assert_eq!(HexByteDecoderA::decode_simd(hex_bytes), Some(bytes));
/*const HEX_BYTES_INVALID: &[[u8; 2]] = &[
['f' as u8, 'g' as u8],
['0' as u8, 'g' as u8],
['1' as u8, 'g' as u8],
['e' as u8, 'g' as u8],
['f' as u8, 'g' as u8],
['0' as u8, 'g' as u8],
['f' as u8, 'g' as u8],
];
for hb in HEX_BYTES_INVALID {
assert_eq!(hex_byte(hb[0], hb[1]), None);
assert!(hex_byte_niched(hb[0], hb[1]) & WIDE_INVALID_BIT != 0);
}*/
}
macro_rules! test_f {
(boxed $f:ident) => {
test_f!(@ $f, Box::as_ref)
};
(@ $f:ident, $trans:expr) => {
for (i, Sample { bytes, hex_bytes }) in SAMPLES.into_iter().enumerate() {
let result = $f(hex_bytes.as_bytes());
assert_eq!(
result.as_ref().map($trans),
Some(bytes.as_bytes()),
"Sample {i} ({hex_bytes:?} => {bytes:?}) did not decode correctly (expected Some)"
);
}
for (i, hex_bytes) in INVALID_SAMPLES.into_iter().enumerate() {
let result = $f(hex_bytes.as_bytes());
assert_eq!(
result.as_ref().map($trans),
None,
"Sample {i} ({hex_bytes:?}) did not decode correctly (expected None)"
);
}
};
}
#[cfg(feature = "alloc")]
#[test]
fn test_dyn_iter_option() {
test_f!(boxed hex_bytes_dyn);
}
#[cfg(feature = "alloc")]
#[test]
fn test_dyn_unsafe() {
test_f!(boxed hex_bytes_dyn_unsafe);
}
#[cfg(feature = "alloc")]
#[test]
fn test_dyn_unsafe_iter() {
test_f!(boxed hex_bytes_dyn_unsafe_iter);
}
}

548
src/enc.rs Normal file
View File

@ -0,0 +1,548 @@
//! SIMD-accelerated hex encoding.
use std::mem::MaybeUninit;
use std::simd::*;
use crate::prelude::*;
use simd::{DIGIT_BATCH_SIZE, GATHER_BATCH_SIZE, SIMD_WIDTH, WIDE_BATCH_SIZE};
const REQUIRED_ALIGNMENT: usize = 64;
pub const HEX_CHARS_LOWER: [u8; 16] = array_op!(map[16, ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f']] |_, c| c as u8);
pub const HEX_CHARS_UPPER: [u8; 16] =
array_op!(map[16, HEX_CHARS_LOWER] |_, c| (c as char).to_ascii_uppercase() as u8);
const __HEX_CHARS_LOWER_SIMD: [u32; 16] = util::cast_u8_u32(HEX_CHARS_LOWER);
const HEX_CHARS_LOWER_SIMD: *const i32 = &__HEX_CHARS_LOWER_SIMD as *const u32 as *const i32;
const __HEX_CHARS_UPPER_SIMD: [u32; 16] = util::cast_u8_u32(HEX_CHARS_UPPER);
const HEX_CHARS_UPPER_SIMD: *const i32 = &__HEX_CHARS_UPPER_SIMD as *const u32 as *const i32;
// TODO: add a check for endianness (current is assumed LE)
const HEX_BYTES_LOWER: [u16; 256] = array_op!(gen[256] |i| ((HEX_CHARS_LOWER[(i & 0xf0) >> 4] as u16)) | ((HEX_CHARS_LOWER[i & 0x0f] as u16) << 8));
const HEX_BYTES_UPPER: [u16; 256] = array_op!(gen[256] |i| ((HEX_CHARS_UPPER[(i & 0xf0) >> 4] as u16)) | ((HEX_CHARS_UPPER[i & 0x0f] as u16) << 8));
macro_rules! select {
($cond:ident ? $true:ident : $false:ident) => {
if $cond {
$true
} else {
$false
}
};
(($cond:expr) ? ($true:expr) : ($false:expr)) => {
if $cond {
$true
} else {
$false
}
};
}
#[inline(always)]
fn nbl_to_ascii<const UPPER: bool>(nbl: u8) -> u8 {
// fourth bit set if true
let at_least_10 = {
let b1 = nbl & 0b1010;
let b2 = nbl & 0b1100;
((nbl >> 1) | (b2 & (b2 << 1)) | (b1 & (b1 << 2))) & 0b1000
};
// 6th bit is always 1 with a-z and 0-9
let b6_val = if UPPER { (at_least_10 ^ 0b1000) << 2 } else { 0b100000 };
// 5th bit is always 1 with 0-9
let b5_val = (at_least_10 ^ 0b1000) << 1;
// 7th bit is always 1 with a-z and A-Z
let b7_val = at_least_10 << 3;
// fill all bits with the value of the 4th bit
let is_at_least_10_all_mask = (((at_least_10 << 4) as i8) >> 7) as u8;
// sub 9 if we're >=10
// a-z and A-Z start at ..0001 rather than ..0000 like 0-9, so we sub 9, not 10
let sub = 9 & is_at_least_10_all_mask;
// apply the sub, then OR in the constants
(nbl - sub) | b6_val | b5_val | b7_val
}
#[inline(always)]
fn nbl_wide_to_ascii<const UPPER: bool>(nbl: u16) -> u16 {
// fourth bit set if true
let at_least_10 = {
let b1 = nbl & 0b1010;
let b2 = nbl & 0b1100;
((nbl >> 1) | (b2 & (b2 << 1)) | (b1 & (b1 << 2))) & 0b1000
};
// mask used don the 6th bit.
let b6_val = if UPPER { (at_least_10 ^ 0b1000) << 2 } else { 0b100000 };
let b5_val = (at_least_10 ^ 0b1000) << 1;
let b7_val = at_least_10 << 3;
// sign extend the 1 if set
let is_at_least_10_all_mask = (((at_least_10 << 12) as i16) >> 15) as u16;
let sub = 9 & is_at_least_10_all_mask;
let c = (nbl - sub) | b6_val | b5_val | b7_val;
c
}
// the way this is used, is by inserting the u16 directly into a byte array, so on a little-endian system (assumed in the code), we need the low byte shifted to the left, which seems counterintuitive.
#[inline(always)]
fn byte_to_ascii<const UPPER: bool>(byte: u8) -> u16 {
//let byte = byte as u16;
//nbl_wide_to_ascii::<UPPER>((byte & 0xf0) >> 4) | (nbl_wide_to_ascii::<UPPER>(byte & 0x0f) << 8)
(nbl_to_ascii::<UPPER>((byte & 0xf0) >> 4) as u16) | ((nbl_to_ascii::<UPPER>(byte & 0x0f) as u16) << 8)
}
macro_rules! const_impl1 {
($UPPER:ident, $src:ident, $dst:ident) => {{
let mut i = 0;
const UNROLL: usize = 8;
let ub = $dst.len();
let aub = util::align_down_to::<{ 2 * UNROLL }>(ub);
let mut src = $src.as_ptr();
let mut dst = $dst.as_mut_ptr() as *mut u8;
while i < aub {
unsafe {
let [b1, b2, b3, b4, b5, b6, b7, b8] = [(); UNROLL];
unroll!(let [b1, b2, b3, b4, b5, b6, b7, b8] => |_| {
let b = *src;
src = src.add(1);
b
});
unroll!([(0, b1), (2, b2), (4, b3), (6, b4), (8, b5), (10, b6), (12, b7), (14, b8)] => |i, b| {
*dst.add(i) = *select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER).get_unchecked((b >> 4) as usize)
});
unroll!([(0, b1), (2, b2), (4, b3), (6, b4), (8, b5), (10, b6), (12, b7), (14, b8)] => |i, b| {
*dst.add(i + 1) = *select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER).get_unchecked((b & 0x0f) as usize)
});
dst = dst.add(2 * UNROLL);
i += 2 * UNROLL;
}
}
while i < ub {
unsafe {
let b = *src;
*dst = *select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER).get_unchecked((b >> 4) as usize);
dst = dst.add(1);
*dst = *select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER).get_unchecked((b & 0x0f) as usize);
dst = dst.add(1);
i += 2;
src = src.add(1);
}
}
}};
}
#[inline(always)]
fn u64_to_ne_u16(v: u64) -> [u16; 4] {
unsafe { std::mem::transmute(v.to_ne_bytes()) }
}
macro_rules! const_impl {
($UPPER:ident, $src:ident, $dst:ident) => {{
let mut i = 0;
const UNROLL: usize = 8;
let ub = $src.len();
let aub = util::align_down_to::<{ UNROLL }>(ub);
let mut src = $src.as_ptr() as *const u64;
//let mut dst = $dst.as_mut_ptr() as *mut u16;
let mut dst = $dst.as_mut_ptr() as *mut u16;
while i < aub {
unsafe {
let [b1, b2, b3, b4, b5, b6, b7, b8] = (src.read_unaligned()).to_ne_bytes();
src = src.add(1);
//let [b1, b2, b3, b4, b5, b6, b7, b8] = [(); UNROLL];
//unroll!(let [b1, b2, b3, b4, b5, b6, b7, b8] => |_| {
// let b = *src;
// src = src.add(1);
// b
//});
//let mut buf1: u64 = 0;
//let mut buf2: u64 = 0;
unroll!(let [b1, b2, b3, b4, b5, b6, b7, b8] => |b| {
*select!($UPPER ? HEX_BYTES_UPPER : HEX_BYTES_LOWER).get_unchecked(b as usize)
});
//unroll!(let [b1: (0, b1), b2: (1, b2), b3: (2, b3), b4: (3, b4), b5: (4, b5), b6: (5, b6), b7: (6, b7), b8: (7, b8)] => |i, v| {
// if i < 4 {
// (v as u64) << (i * 16)
// } else {
// (v as u64) << ((i - 4) * 16)
// }
//});
unroll!([(0, b1), (1, b2), (2, b3), (3, b4), (4, b5), (5, b6), (6, b7), (7, b8)] => |_, v| {
//*dst = *select!($UPPER ? HEX_BYTES_UPPER : HEX_BYTES_LOWER).get_unchecked(b as usize);
*dst = v;
//if i < 4 {
// //println!("[{i}] {v:064b}");
// buf1 |= v;
//} else {
// //println!("[{i}] {v:064b}");
// buf2 |= v;
//}
// if i < 4 {
// buf1[i] = MaybeUninit::new(v);
// } else {
// buf2[i - 4] = MaybeUninit::new(v);
// }
//*dst = byte_to_ascii::<$UPPER>(b);
dst = dst.add(1);
});
// TODO: would using vector store actually be faster here (particularly for the
// heap variant)
//assert!(dst < ($dst.as_mut_ptr() as *mut u64).add($dst.len()));
//*dst = buf1;
//dst = dst.add(1);
//assert!(dst < ($dst.as_mut_ptr() as *mut u64).add($dst.len()));
//*dst = buf2;
//dst = dst.add(1);
i += UNROLL;
}
}
let mut src = src as *const u8;
let mut dst = dst as *mut u16;
while i < ub {
unsafe {
let b = *src;
*dst = *select!($UPPER ? HEX_BYTES_UPPER : HEX_BYTES_LOWER).get_unchecked(b as usize);
//*dst = byte_to_ascii::<$UPPER>(b);
dst = dst.add(1);
src = src.add(1);
i += 1;
}
}
}};
}
/// The `$dst` must be 32-byte aligned.
macro_rules! common_impl {
($UPPER:ident, $src:ident, $dst:ident) => {
const_impl!($UPPER, $src, $dst)
};
(@disabled $UPPER:ident, $src:ident, $dst:ident) => {{
let mut i = 0;
let ub = $dst.len();
let aub = util::align_down_to::<DIGIT_BATCH_SIZE>(ub);
let mut src = $src.as_ptr();
let mut dst = $dst.as_mut_ptr();
while i < aub {
unsafe {
//let hi_los = $src.as_ptr().add(i) as *const [u8; GATHER_BATCH_SIZE];
//let chunk = $src.as_ptr().add(i >> 1) as *const [u8; WIDE_BATCH_SIZE];
//let chunk = *chunk;
//let chunk: simd::arch::__m128i = Simd::from_array(chunk).into();
let chunk: simd::arch::__m128i;
std::arch::asm!("vmovdqu {dst}, [{src}]", src = in(reg) src, dst = lateout(xmm_reg) chunk);
let hi = chunk.and(0xf0u8.splat().into());
let hi: simd::arch::__m128i = simd::shr_64!(4, (xmm_reg) hi);
let lo = chunk.and(0x0fu8.splat().into());
unroll!(let [hi, lo] => |x| Simd::<u8, WIDE_BATCH_SIZE>::from(x));
if_trace_simd! {
println!("hi,lo: {hi:02x?}, {lo:02x?}");
}
// TODO: find a more efficient approach
let hi = hi.cast::<u32>();
let lo = lo.cast::<u32>();
// just trunc these
let a: simd::arch::__m256i = util::cast(hi);
let c: simd::arch::__m256i = util::cast(lo);
// need to shift these over
unroll!(let [hi, lo] => |x| (&x as *const _ as *const [u32; 8]).add(1));
let b: simd::arch::__m256i = Simd::from_array(*hi).into();
let d: simd::arch::__m256i = Simd::from_array(*lo).into();
//unroll!(let [hi, lo] => |x| util::cast::<_, simd::arch::__m256i>(x));
//if_trace_simd! {
// unroll!(let [hi, lo] => |x| Simd::<u8, DIGIT_BATCH_SIZE>::from(x));
// println!("hi,lo: {hi:02x?}, {lo:02x?}");
//}
//let a = hi;
//let b: simd::arch::__m128i;
//std::arch::asm!("vpermq {:x}, {:y}, 1", lateout(xmm_reg) b, in(xmm_reg) hi);
//let b: simd::arch::__m128i;
//std::arch::asm!("vextracti128 {:x}, {:y}, 1", lateout(xmm_reg) b, in(xmm_reg) hi);
//let c = lo;
//let d: simd::arch::__m128i;
//std::arch::asm!("vextracti128 {:x}, {:y}, 1", lateout(xmm_reg) d, in(xmm_reg) lo);
//unroll!(let [a, b, c, d] => |x| {
// let o: simd::arch::__m256i;
// std::arch::asm!("vpmovzxdq {}, {}", lateout(ymm_reg) o, in(xmm_reg) x);
// o
//});
//let a = hi.widen::<4, 0, false>();
//let b = hi.widen::<4, 8, false>();
//let c = lo.widen::<4, 0, false>();
//let d = lo.widen::<4, 8, false>();
if_trace_simd! {
unroll!(let [a, b, c, d] => |x| Simd::<u32, GATHER_BATCH_SIZE>::from(x));
println!("a,b,c,d: {a:02x?}, {b:02x?}, {c:02x?}, {d:02x?}");
}
unroll!(let [a, b, c, d] => |x| simd::arch::__m256i::from(x));
// let indices: simd::arch::__m256i = Simd::from_array([]).into();
// std::arch::asm!("vpshufb {out:y}, {in:y}, {indices}", out = lateout(ymm_reg) a, in = in(ymm_reg) hi, indices = in(ymm_reg) );
unroll!(let [a, b, c, d] => |x| simd::arch::_mm256_i32gather_epi32(select!($UPPER ? HEX_CHARS_UPPER_SIMD : HEX_CHARS_LOWER_SIMD), x, 4));
unroll!(let [a, b, c, d] => |x| Simd::<u32, GATHER_BATCH_SIZE>::from(x).cast::<u8>());
if_trace_simd! {
println!("a,b,c,d: {a:02x?}, {b:02x?}, {c:02x?}, {d:02x?}");
}
// load the 64-bit integers into registers
unroll!(let [a, b, c, d] => |x| util::cast::<_, u64>(x).load_128());
if_trace_simd! {
unroll!(let [a, b, c, d] => |x| Simd::<u8, 16>::from(x));
println!("a,b,c,d: {a:02x?}, {b:02x?}, {c:02x?}, {d:02x?}");
}
// copy the second 64-bit integer into the upper half of xmm0 (lower half is the first 64-bit integer)
let ab = simd::merge_lo_hi_m128(a, b);
// copy the fourth 64-bit integer into the upper half of xmm2 (lower half is the third 64-bit integer)
let cd = simd::merge_lo_hi_m128(c, d);
if_trace_simd! {
unroll!(let [ab, cd] => |x| Simd::<u8, 16>::from(x));
println!("ab,cd: {ab:x?}, {cd:x?}");
}
let ab1: simd::arch::__m256i;
let cd1: simd::arch::__m256i;
std::arch::asm!("vpunpcklbw {:y}, {:y}, {:y}", lateout(ymm_reg) ab1, in(ymm_reg) ab, in(ymm_reg) cd);
std::arch::asm!("vpunpckhbw {:y}, {:y}, {:y}", lateout(ymm_reg) cd1, in(ymm_reg) ab, in(ymm_reg) cd);
//let abcd = simd::merge_m128_m256(util::cast(ab1), util::cast(cd1));
//let abcd = ab1;
let abcd: simd::arch::__m256i;
core::arch::asm!("vinserti128 {:y}, {:y}, {:x}, 0x1", lateout(ymm_reg) abcd, in(ymm_reg) ab1, in(ymm_reg) cd1, options(pure, nomem, preserves_flags, nostack));
//core::arch::asm!("vinserti128 {:y}, {:y}, {:x}, 0x1", lateout(ymm_reg) abcd, in(ymm_reg) ab1, in(ymm_reg) cd1);
// merge the xmm0 and xmm1 (ymm1) registers into ymm0
//let abcd = simd::merge_m128_m256(ab, cd);
if_trace_simd! {
let abcd: Simd<u8, DIGIT_BATCH_SIZE> = abcd.into();
println!("abcd: {abcd:x?}");
}
// HA! there's an undocumented requirement for the dest to be 32-byte aligned.
//assert_eq!((ptr.cast::<u8>() as usize) & 32 - 1, 0);
core::arch::asm!("vmovdqa [{}], {}", in(reg) dst as *mut i8, in(ymm_reg) abcd, options(preserves_flags, nostack));
dst = dst.add(DIGIT_BATCH_SIZE);
i += DIGIT_BATCH_SIZE;
src = src.add(WIDE_BATCH_SIZE);
}
}
while i < ub {
unsafe {
let b = *src;
*dst = MaybeUninit::new(*select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER).get_unchecked((b >> 4) as usize));
dst = dst.add(1);
*dst = MaybeUninit::new(*select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER).get_unchecked((b & 0x0f) as usize));
dst = dst.add(1);
i += 2;
src = src.add(1);
}
}
if_trace_simd! {
let slice: &[_] = $dst.as_ref();
match std::str::from_utf8(unsafe { &*(slice as *const [_] as *const [u8]) }) {
Ok(s) => {
println!("encoded: {s:?}");
}
Err(e) => {
println!("encoded corrupted utf8: {e}");
}
}
}
}};
}
macro_rules! define_encode_str {
($name:ident$(<$N:ident>)?($in:ty) $(where $( $where:tt )+)?) => {
fn $name$(<const $N: usize>)?(src: $in) -> String $( where $( $where )+ )?;
};
}
macro_rules! impl_encode_str {
($name:ident$(<$N:ident>)?($in:ty) => $impl:ident (|$bytes:ident| $into_vec:expr) $(where $( $where:tt )+)?) => {
#[inline]
fn $name$(<const $N: usize>)?(src: $in) -> String $( where $( $where )+ )? {
let $bytes = Self::$impl(src);
unsafe { String::from_utf8_unchecked($into_vec) }
}
};
}
pub trait Encode {
/// Encodes the sized input on the stack.
fn enc_sized<const N: usize>(src: &[u8; N]) -> [u8; N * 2]
where
[u8; N * 2]:;
/// Encodes the sized input on the heap.
fn enc_sized_heap<const N: usize>(src: &[u8; N]) -> Box<[u8; N * 2]>
where
[u8; N * 2]:;
/// Encodes the unsized input on the heap.
fn enc_slice(src: &[u8]) -> Box<[u8]>;
define_encode_str!(enc_str_sized<N>(&[u8; N]) where [u8; N * 2]:);
define_encode_str!(enc_str_sized_heap<N>(&[u8; N]) where [u8; N * 2]:);
define_encode_str!(enc_str_slice(&[u8]));
}
pub struct Encoder<const UPPER: bool = false>;
#[repr(align(32))]
struct Aligned32<T>(T);
impl<const UPPER: bool> Encode for Encoder<UPPER> {
#[inline]
fn enc_sized<const N: usize>(src: &[u8; N]) -> [u8; N * 2]
where
[u8; N * 2]:,
{
// SAFETY: `Aligned32` has no initialization in and of itself, nor does an array of `MaybeUninit`
let mut buf =
unsafe { MaybeUninit::<Aligned32<[MaybeUninit<_>; N * 2]>>::uninit().assume_init() };
let buf1 = &mut buf.0;
common_impl!(UPPER, src, buf1);
unsafe { MaybeUninit::array_assume_init(buf.0) }
}
#[inline]
fn enc_sized_heap<const N: usize>(src: &[u8; N]) -> Box<[u8; N * 2]>
where
[u8; N * 2]:,
{
let mut buf: Box<[MaybeUninit<u8>; N * 2]> =
unsafe { util::alloc_aligned_box::<_, REQUIRED_ALIGNMENT>() };
common_impl!(UPPER, src, buf);
unsafe { Box::from_raw(Box::into_raw(buf).cast()) }
}
#[inline]
fn enc_slice(src: &[u8]) -> Box<[u8]> {
let mut buf: Box<[MaybeUninit<u8>]> =
unsafe { util::alloc_aligned_box_slice::<_, REQUIRED_ALIGNMENT>(src.len() * 2) };
common_impl!(UPPER, src, buf);
unsafe { Box::<[_]>::assume_init(buf) }
}
impl_encode_str!(enc_str_sized<N>(&[u8; N]) => enc_sized (|bytes| bytes.into()) where [u8; N * 2]:);
impl_encode_str!(enc_str_sized_heap<N>(&[u8; N]) => enc_sized_heap (|bytes| {
Vec::from_raw_parts(Box::into_raw(bytes) as *mut u8, N * 2, N * 2)
}) where [u8; N * 2]:);
impl_encode_str!(enc_str_slice(&[u8]) => enc_slice (|bytes| Vec::from(bytes)));
}
impl<const UPPER: bool> Encoder<UPPER> {
#[inline]
pub fn enc_const<const N: usize>(mut src: &[u8; N]) -> [u8; N * 2]
where
[u8; N * 2]:,
{
let mut buf = MaybeUninit::uninit_array();
const_impl!(UPPER, src, buf);
unsafe { MaybeUninit::array_assume_init(buf) }
}
}
#[cfg(test)]
mod test {
use super::*;
use crate::test::*;
#[test]
fn test_nbl_to_ascii() {
for i in 0..16 {
let a = nbl_to_ascii::<false>(i);
let b = HEX_CHARS_LOWER[i as usize];
assert_eq!(a, b, "({i}) {a:08b} != {b:08b}");
let a = nbl_to_ascii::<true>(i);
let b = HEX_CHARS_UPPER[i as usize];
assert_eq!(a, b, "({i}) {a:08b} != {b:08b}");
}
}
#[test]
fn test_nbl_wide_to_ascii() {
for i in 0..16 {
let a = nbl_wide_to_ascii::<false>(i);
let b = HEX_CHARS_LOWER[i as usize] as u16;
assert_eq!(a, b, "({i}) {a:08b} != {b:08b}");
let a = nbl_wide_to_ascii::<true>(i);
let b = HEX_CHARS_UPPER[i as usize] as u16;
assert_eq!(a, b, "({i}) {a:08b} != {b:08b}");
}
}
#[test]
fn test_byte_to_ascii() {
for i in 0..=255 {
let a = byte_to_ascii::<false>(i);
let b = HEX_BYTES_LOWER[i as usize];
assert_eq!(a, b, "({i}) {a:016b} != {b:016b}");
let a = byte_to_ascii::<true>(i);
let b = HEX_BYTES_UPPER[i as usize];
assert_eq!(a, b, "({i}) {a:016b} != {b:016b}");
}
}
macro_rules! for_each_sample {
($name:ident, |$ss:pat_param, $shs:pat_param, $sb:pat_param, $shb:pat_param| $expr:expr) => {
#[test]
fn $name() {
let $ss = STR;
let $shs = HEX_STR;
let $sb = BYTES;
let $shb = HEX_BYTES;
$expr;
let $ss = LONG_STR;
let $shs = LONG_HEX_STR;
let $sb = LONG_BYTES;
let $shb = LONG_HEX_BYTES;
$expr;
}
};
}
type Enc = Encoder<true>;
for_each_sample!(enc_const, |_, _, b, hb| assert_eq!(Enc::enc_const(b), *hb));
for_each_sample!(enc_sized, |_, _, b, hb| assert_eq!(Enc::enc_sized(b), *hb));
for_each_sample!(enc_sized_heap, |_, _, b, hb| assert_eq!(
Enc::enc_sized_heap(b),
Box::new(*hb)
));
for_each_sample!(enc_slice, |_, _, b, hb| assert_eq!(
Enc::enc_slice(b),
(*hb).into_iter().collect::<Vec<_>>().into_boxed_slice()
));
for_each_sample!(enc_str_sized, |_, hs, b, _| assert_eq!(Enc::enc_str_sized(b), hs.to_owned()));
for_each_sample!(enc_str_sized_heap, |_, hs, b, _| assert_eq!(
Enc::enc_str_sized_heap(b),
hs.to_owned()
));
for_each_sample!(enc_str_slice, |_, hs, b, _| assert_eq!(
Enc::enc_str_slice(b),
hs.to_owned()
));
}

View File

@ -1,5 +1,7 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![feature(array_chunks)]
#![feature(core_intrinsics)]
#![feature(const_eval_select)]
#![feature(const_slice_index)]
#![feature(const_trait_impl)]
#![feature(extend_one)]
@ -11,812 +13,22 @@
#![feature(const_maybe_uninit_array_assume_init)]
#![feature(const_maybe_uninit_uninit_array)]
#![cfg_attr(feature = "alloc", feature(new_uninit))]
#![feature(stdsimd)]
#![feature(portable_simd)]
pub(crate) mod util;
pub(crate) mod simd;
// ignores warning about `generic_const_exprs`
#![allow(incomplete_features)]
#[cfg(feature = "alloc")]
extern crate alloc;
use core::mem::MaybeUninit;
use core::simd::*;
pub mod simd;
pub(crate) mod util;
#[cfg(feature = "alloc")]
use alloc::{boxed::Box, vec::Vec};
#[doc(hidden)]
#[cfg(any(test, feature = "test"))]
pub mod test;
use simd::SimdTestAnd as _;
use simd::SimdBitwise as _;
pub mod dec;
pub mod enc;
use util::array_op;
// use the maximum batch size that would be supported by AVX-512
//pub const SIMD_WIDTH: usize = 512;
pub const SIMD_WIDTH: usize = 256;
/// The batch size used for the "wide" decoded hex bytes (any bit in the upper half indicates an error).
pub const WIDE_BATCH_SIZE: usize = SIMD_WIDTH / 16;
/// The batch size used for the hex digits.
pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2;
const GATHER_BATCH_SIZE: usize = DIGIT_BATCH_SIZE / 4;
macro_rules! if_trace_simd {
($( $tt:tt )*) => {
// disabled
//{ $( $tt )* }
};
}
const VALIDATE: bool = true;
#[inline]
const fn alternating_indices<const N: usize>(first_bias: bool) -> [usize; N] {
if first_bias {
array_op!(gen[N] |i| i * 2)
} else {
array_op!(gen[N] |i| i * 2 + 1)
}
}
#[inline]
const fn cast_u8_u32<const N: usize>(arr: [u8; N]) -> [u32; N] {
array_op!(map[N, arr] |_, v| v as u32)
}
const MSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(true);
const LSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(false);
pub const INVALID_BIT: u8 = 0b1000_0000;
pub const WIDE_INVALID_BIT: u16 = 0b1000_1000_0000_0000;
const ASCII_DIGITS: [u8; 256] = {
array_op!(gen[256] |i| {
const DIGIT_MIN: u8 = '0' as u8;
const DIGIT_MAX: u8 = '9' as u8;
const LOWER_MIN: u8 = 'a' as u8;
const LOWER_MAX: u8 = 'f' as u8;
const UPPER_MIN: u8 = 'A' as u8;
const UPPER_MAX: u8 = 'F' as u8;
let i = i as u8;
match i {
DIGIT_MIN..=DIGIT_MAX => i - DIGIT_MIN,
LOWER_MIN..=LOWER_MAX => 10 + i - LOWER_MIN,
UPPER_MIN..=UPPER_MAX => 10 + i - UPPER_MIN,
_ => INVALID_BIT,
}
})
};
const __ASCII_DIGITS_SIMD: [u32; 256] = cast_u8_u32(ASCII_DIGITS);
const ASCII_DIGITS_SIMD: *const i32 = &__ASCII_DIGITS_SIMD as *const u32 as *const i32;
/// Returns [`INVALID_BIT`] if invalid. Based on `char.to_digit()` in the stdlib.
#[inline]
pub const fn hex_digit(ascii: u8) -> u8 {
ASCII_DIGITS[ascii as usize]
}
#[inline(always)]
pub fn hex_digit_simd<const LANES: usize>(ascii: Simd<u8, LANES>) -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount,
{
unsafe {
Simd::gather_select_unchecked(
&ASCII_DIGITS,
Mask::splat(true),
ascii.cast(),
simd::splat_0u8::<LANES>(),
)
}
}
/// Parses an ascii hex byte.
#[inline(always)]
pub const fn hex_byte(msb: u8, lsb: u8) -> Option<u8> {
let msb = hex_digit(msb);
let lsb = hex_digit(lsb);
// second is faster (perhaps it pipelines better?)
//if (msb | lsb) & INVALID_BIT != 0 {
if (msb & INVALID_BIT) | (lsb & INVALID_BIT) != 0 {
return None;
}
Some(msb << 4 | lsb)
}
/// A decoder for a single hex byte.
#[const_trait]
pub trait HexByteDecoder {
/// Parses an ascii hex byte. Any return value exceeding [`u8::MAX`] indicates invalid input.
fn decode_unpacked(hi: u8, lo: u8) -> u16;
/// Parses an ascii hex byte. Any return value exceeding [`u8::MAX`] indicates invalid input.
#[inline(always)]
fn decode_packed([hi, lo]: &[u8; 2]) -> u16 {
Self::decode_unpacked(*hi, *lo)
}
}
/// A decoder for a sized batch of hex bytes.
pub trait HexByteSimdDecoder {
/// Parses an ascii hex byte. Any element of the return value exceeding [`u8::MAX`] indicates invalid input.
fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>>;
}
pub struct HexByteDecoderA;
impl const HexByteDecoder for HexByteDecoderA {
#[inline(always)]
fn decode_unpacked(hi: u8, lo: u8) -> u16 {
let hi = hex_digit(hi) as u16;
let lo = hex_digit(lo) as u16;
// might these these masks allow the ORs the be pipelined more efficiently?
(hi << 4) | (lo & 0xf) | ((lo & 0xf0) << 8)
}
#[inline(always)]
fn decode_packed([hi, lo]: &[u8; 2]) -> u16 {
let hi = hex_digit(*hi) as u16;
let lo = hex_digit(*lo) as u16;
(hi << 4) | (lo & 0xf) | ((lo & 0xf0) << 8)
}
}
impl HexByteSimdDecoder for HexByteDecoderA {
#[inline(always)]
fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>> {
let hex_digits = hex_digit_simd::<DIGIT_BATCH_SIZE>(Simd::from_array(hi_los));
if ((hex_digits & simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT))
.simd_ne(simd::splat_0u8::<DIGIT_BATCH_SIZE>()))
.any()
{
return None;
}
let msb = simd_swizzle!(hex_digits, MSB_INDICES);
let lsb = simd_swizzle!(hex_digits, LSB_INDICES);
Some((msb << simd::splat_n::<WIDE_BATCH_SIZE>(4)) | lsb)
}
}
pub struct HexByteDecoderB;
impl const HexByteDecoder for HexByteDecoderB {
util::defer_impl! {
=> HexByteDecoderA;
//fn decode_unpacked(hi: u8, lo: u8) -> u16;
//fn decode_packed(hi_lo: &[u8; 2]) -> u16;
}
#[inline(always)]
fn decode_unpacked(hi: u8, lo: u8) -> u16 {
let lo = hex_digit(lo) as u16;
let hi = hex_digit(hi) as u16;
// kind of bizarre: changing the order of these decreases perf by 6-12%
(hi << 4) | lo | ((lo & INVALID_BIT as u16) << 8)
}
#[inline(always)]
fn decode_packed([hi, lo]: &[u8; 2]) -> u16 {
let lo = hex_digit(*lo) as u16;
let hi = hex_digit(*hi) as u16;
(hi << 4) | lo | ((lo & INVALID_BIT as u16) << 8)
}
}
macro_rules! hex_digits_simd_inline {
($ptr:ident) => {{
if_trace_simd! {
println!("hi_los: {:x?}", *$ptr);
}
let a = *$ptr;
let b = *$ptr.add(1);
let c = *$ptr.add(2);
let d = *$ptr.add(3);
if_trace_simd! {
let f = |x| __ASCII_DIGITS_SIMD[x as usize];
println!(
"{:x?}, {:x?}, {:x?}, {:x?}",
a.map(f),
b.map(f),
c.map(f),
d.map(f)
);
}
let a = Simd::from_array(a);
let b = Simd::from_array(b);
let c = Simd::from_array(c);
let d = Simd::from_array(d);
let a = a.cast::<u32>();
let b = b.cast::<u32>();
let c = c.cast::<u32>();
let d = d.cast::<u32>();
if_trace_simd! {
println!("{a:x?}, {b:x?}, {c:x?}, {d:x?}");
}
let a = a.into();
let b = b.into();
let c = c.into();
let d = d.into();
let a = simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, a, 4);
let b = simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, b, 4);
let c = simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, c, 4);
let d = simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, d, 4);
let a = Simd::<u32, GATHER_BATCH_SIZE>::from(a).cast::<u8>();
let b = Simd::<u32, GATHER_BATCH_SIZE>::from(b).cast::<u8>();
let c = Simd::<u32, GATHER_BATCH_SIZE>::from(c).cast::<u8>();
let d = Simd::<u32, GATHER_BATCH_SIZE>::from(d).cast::<u8>();
if_trace_simd! {
println!("{a:x?}, {b:x?}, {c:x?}, {d:x?}");
}
// load the 64-bit integers into registers
let a = simd::load_u64_m128(util::cast(a));
let b = simd::load_u64_m128(util::cast(b));
let c = simd::load_u64_m128(util::cast(c));
let d = simd::load_u64_m128(util::cast(d));
if_trace_simd! {
let a = Simd::<u8, 16>::from(a);
let b = Simd::<u8, 16>::from(b);
let c = Simd::<u8, 16>::from(c);
let d = Simd::<u8, 16>::from(d);
println!("a,b,c,d: {a:x?}, {b:x?}, {c:x?}, {d:x?}");
}
// copy the second 64-bit integer into the upper half of xmm0 (lower half is the first 64-bit integer)
let ab = simd::merge_lo_hi_m128(a, b);
// copy the fourth 64-bit integer into the upper half of xmm2 (lower half is the third 64-bit integer)
let cd = simd::merge_lo_hi_m128(c, d);
if_trace_simd! {
let ab = Simd::<u8, 16>::from(ab);
let cd = Simd::<u8, 16>::from(cd);
println!("ab,cd: {ab:x?}, {cd:x?}");
}
// merge the xmm0 and xmm1 (ymm1) registers into ymm0
let abcd = simd::merge_m128_m256(ab, cd);
if_trace_simd! {
let abcd: Simd<u8, DIGIT_BATCH_SIZE> = abcd.into();
println!("abcd: {abcd:x?}");
}
abcd
}};
}
macro_rules! merge_hex_digits_into_bytes_inline {
($hex_digits:ident) => {{
let msb = simd::extract_lo_bytes($hex_digits);
let lsb = simd::extract_hi_bytes($hex_digits);
let msb1: simd::arch::__m128i;
unsafe { std::arch::asm!("vpsllw {dst}, {src}, 4", src = in(xmm_reg) msb, dst = lateout(xmm_reg) msb1) };
if_trace_simd! {
let msb1: Simd<u8, WIDE_BATCH_SIZE> = msb1.into();
println!("msb1: {msb1:x?}");
}
let msb2 = msb1.and(Simd::from_array([0xf0f0u16; WIDE_BATCH_SIZE / 2]).into());
let b = msb2.or(lsb);
if_trace_simd! {
let msb: Simd<u8, WIDE_BATCH_SIZE> = msb.into();
let msb1: Simd<u8, WIDE_BATCH_SIZE> = msb1.into();
let msb2: Simd<u8, WIDE_BATCH_SIZE> = msb2.into();
let lsb: Simd<u8, WIDE_BATCH_SIZE> = lsb.into();
let b: Simd<u8, WIDE_BATCH_SIZE> = b.into();
println!("| Packed | Msb | <<4 | & | Lsb | Bytes | |");
Simd::<u8, DIGIT_BATCH_SIZE>::from($hex_digits)
.to_array()
.chunks(2)
.zip(msb.to_array())
.zip(msb1.to_array())
.zip(msb2.to_array())
.zip(lsb.to_array())
.zip(b.to_array())
.for_each(|(((((chunk, msb), msb1), msb2), lsb), b)| {
println!(
"| {chunk:02x?} | {msb:x?} | {msb1:x?} | {msb2:x?} | {lsb:x?} | {b:02x?} | {ok} |",
chunk = (chunk[0] as u16) << 4 | (chunk[1] as u16),
ok = if chunk[0] == msb && chunk[1] == lsb {
'✓'
} else {
'✗'
}
);
});
}
b
}};
}
impl HexByteSimdDecoder for HexByteDecoderB {
util::defer_impl! {
=> HexByteDecoderA;
//fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>>;
}
#[inline(always)]
fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>> {
let hi_los = hi_los.as_ptr() as *const [u8; GATHER_BATCH_SIZE];
let hex_digits = unsafe { hex_digits_simd_inline!(hi_los) };
if hex_digits.test_and_non_zero(simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT).into()) {
return None;
}
Some(merge_hex_digits_into_bytes_inline!(hex_digits).into())
}
}
pub type HBD = HexByteDecoderB;
pub mod conv {
use core::simd::{LaneCount, Simd, SupportedLaneCount};
use crate::util;
#[inline(always)]
pub const fn u8_to_u16<const N_OUT: usize>(a: [u8; N_OUT * 2]) -> [u16; N_OUT] {
unsafe { util::cast(a) }
}
#[inline(always)]
pub const fn u8x2_to_u8<const N_IN: usize>(a: [[u8; 2]; N_IN]) -> [u8; N_IN * 2] {
unsafe { util::cast(a) }
}
#[inline(always)]
pub const fn simdu8_to_simdu16<const N_OUT: usize>(
a: Simd<u8, { N_OUT * 2 }>,
) -> Simd<u16, N_OUT>
where
LaneCount<{ N_OUT * 2 }>: SupportedLaneCount,
LaneCount<N_OUT>: SupportedLaneCount,
{
unsafe { util::cast(a) }
}
}
macro_rules! decode_hex_bytes_non_vectored {
($i:ident, $ascii:ident, $bytes:ident) => {{
//let mut bad = 0u16;
let mut bad = 0u8;
while $i < $ascii.len() {
/*let b = HBD::decode_packed(unsafe { &*($ascii.as_ptr().add($i) as *const [u8; 2]) });
bad |= b;
unsafe { *$bytes.get_unchecked_mut($i >> 1) = MaybeUninit::new(b as u8) };*/
let [hi, lo] = unsafe { *($ascii.as_ptr().add($i) as *const [u8; 2]) };
let lo = hex_digit(lo);
let hi = hex_digit(hi);
bad |= lo;
bad |= hi;
/*if (hi & INVALID_BIT) | (lo & INVALID_BIT) != 0 {
println!("bad hex byte at {} ({}{})", $i, $ascii[$i] as char, $ascii[$i + 1] as char);
}*/
let b = (hi << 4) | lo;
unsafe { *$bytes.get_unchecked_mut($i >> 1) = MaybeUninit::new(b) };
$i += 2;
}
//if (bad & WIDE_INVALID_BIT) != 0 {
if (bad & INVALID_BIT) != 0 {
return false;
}
}};
}
/*simd::swizzle_indices!(MSB_INDICES = [
0, 2, 4, 6,
8, 10, 12, 14,
16, 18, 20, 22,
24, 26, 28, 30
], [_ . . . _ . . . _ . . . _ . . .]);
simd::swizzle_indices!(LSB_INDICES = [
1, 3, 5, 7,
9, 11, 13, 15,
17, 19, 21, 23,
25, 27, 29, 31
], [_ . . . _ . . . _ . . . _ . . .]);*/
#[inline(always)]
fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bool {
// these checks should always be eliminated because they are performed more efficiently
// (sometimes statically) in the callers, but they provide a major safeguard against nasty
// memory safety issues.
debug_assert_eq!(
ascii.len() >> 1 << 1,
ascii.len(),
"len of ascii is not a multiple of 2"
);
if ascii.len() >> 1 << 1 != ascii.len() {
return false;
}
debug_assert_eq!(
ascii.len() >> 1,
bytes.len(),
"len of ascii is not twice that of bytes"
);
if ascii.len() >> 1 != bytes.len() {
return false;
}
const VECTORED: bool = true;
if VECTORED {
use simd::arch;
let mut bad: arch::__m256i = simd::splat_0u8().into();
let mut i = 0;
while i < util::align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) {
let hex_digits = unsafe {
let hi_los = ascii.as_ptr().add(i) as *const [u8; GATHER_BATCH_SIZE];
hex_digits_simd_inline!(hi_los)
};
if VALIDATE {
unsafe {
core::arch::asm!("vpor {bad}, {digits}, {bad}", bad = inout(ymm_reg) bad, digits = in(ymm_reg) hex_digits, options(pure, nomem, preserves_flags, nostack));
}
}
let buf = merge_hex_digits_into_bytes_inline!(hex_digits);
unsafe {
// vmovaps xmm0, xmmword ptr [rsi]
// vmovups xmmword ptr [rdi], xmm0
//core::arch::asm!("vmovdqu8 {}, [{}]", in(xmm_reg) buf, in(reg) bytes.as_mut_ptr().add(i >> 1) as *mut i8);
//let all: arch::__m128i = Mask::<i64, 2>::splat(true).to_int().into();
//core::arch::asm!("vpmaskmovq {}, {}, [{}]", in(xmm_reg) buf, in(xmm_reg) all, in(xmm_reg) bytes.as_mut_ptr().add(i >> 1) as *mut i8);
//core::arch::asm!("vpmaskmovq {}, {}, [{}]", in(xmm_reg) buf, in(xmm_reg) 0u64, in(xmm_reg) bytes.as_mut_ptr().add(i >> 1) as *mut i8);
// arch::_mm_storeu_epi8(bytes.as_mut_ptr().add(i >> 1) as *mut i8, buf)
//arch::_mm_maskstore_epi64(bytes.as_mut_ptr().add(i >> 1) as *mut i64, core::mem::transmute(!0u128), buf);
core::arch::asm!("vmovdqa [{}], {}", in(reg) bytes.as_mut_ptr().add(i >> 1) as *mut i8, in(xmm_reg) buf, options(preserves_flags, nostack));
};
i += DIGIT_BATCH_SIZE;
}
decode_hex_bytes_non_vectored!(i, ascii, bytes);
!bad.test_and_non_zero(simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT).into())
} else {
let mut i = 0;
decode_hex_bytes_non_vectored!(i, ascii, bytes);
true
}
}
/// Use of this function should be restricted to `const` contexts because it is not vectorized like
/// the non-`const` alternative.
#[inline]
pub const fn hex_bytes_sized_const<const N: usize>(ascii: &[u8; N * 2]) -> Option<[u8; N]> {
if N == 0 {
Some([0u8; N])
} else {
let mut bytes = MaybeUninit::uninit_array();
let mut i = 0;
while i < N * 2 {
if i >> 1 >= bytes.len() {
unsafe { core::hint::unreachable_unchecked() };
}
match hex_byte(unsafe { *ascii.get_unchecked(i) }, unsafe {
*ascii.get_unchecked(i + 1)
}) {
Some(b) => bytes[i >> 1] = MaybeUninit::new(b),
None => return None,
}
i += 2;
}
Some(unsafe { MaybeUninit::array_assume_init(bytes) })
}
}
#[inline]
pub fn hex_bytes_sized<const N: usize>(ascii: &[u8; N * 2]) -> Option<[u8; N]> {
if N == 0 {
Some([0u8; N])
} else {
let mut bytes = MaybeUninit::uninit_array();
if decode_hex_bytes_unchecked(ascii, &mut bytes) {
Some(unsafe { MaybeUninit::array_assume_init(bytes) })
} else {
None
}
}
}
#[cfg(feature = "alloc")]
#[inline]
pub fn hex_bytes_sized_heap<const N: usize>(ascii: &[u8; N * 2]) -> Option<Box<[u8; N]>> {
if N == 0 {
Some(Box::new([0u8; N]))
} else {
let mut bytes = unsafe { Box::<[_; N]>::new_uninit().assume_init() };
if decode_hex_bytes_unchecked(ascii, bytes.as_mut()) {
Some(unsafe { Box::from_raw(Box::into_raw(bytes) as *mut [u8; N]) })
} else {
None
}
}
}
#[cfg(feature = "alloc")]
#[inline]
pub fn hex_bytes_dyn_unsafe(ascii: &[u8]) -> Option<Box<[u8]>> {
let len = ascii.len() >> 1;
if len << 1 != ascii.len() {
return None;
}
let mut bytes = Box::new_uninit_slice(len);
if decode_hex_bytes_unchecked(ascii, bytes.as_mut()) {
Some(unsafe { Box::<[_]>::assume_init(bytes) })
} else {
None
}
}
#[cfg(feature = "alloc")]
#[inline]
pub fn hex_bytes_dyn_unsafe_iter(ascii: &[u8]) -> Option<Box<[u8]>> {
let len = ascii.len() >> 1;
if len << 1 != ascii.len() {
return None;
}
let mut bytes = Box::<[u8]>::new_uninit_slice(len);
for (i, [hi, lo]) in ascii.array_chunks::<2>().enumerate() {
let lo = hex_digit(*lo);
let hi = hex_digit(*hi);
if (lo & INVALID_BIT) | (hi & INVALID_BIT) != 0 {
return None;
}
let b = (hi << 4) | lo;
unsafe { *bytes.get_unchecked_mut(i) = MaybeUninit::new(b) };
}
Some(unsafe { Box::<[_]>::assume_init(bytes) })
}
#[cfg(feature = "alloc")]
#[inline]
pub fn hex_bytes_dyn(ascii: &[u8]) -> Option<Box<[u8]>> {
let iter = ascii.array_chunks::<2>();
if iter.remainder().len() != 0 {
return None;
}
iter.map(|[msb, lsb]| hex_byte(*msb, *lsb))
.collect::<Option<Vec<u8>>>()
.map(|v| v.into_boxed_slice())
}
#[cfg(test)]
mod test {
use super::*;
const BYTES: &str = "Donald J. Trump!";
const HEX_BYTES: &str = "446F6E616C64204A2E205472756D7021";
const LONG_BYTES: &str = "Dolorum distinctio ut earum quidem distinctio necessitatibus quam. Sit praesentium facere perspiciatis iure aut sunt et et. Adipisci enim rerum illum et officia nisi recusandae. Vitae doloribus ut quia ea unde consequuntur quae illum. Id eius harum est. Inventore ipsum ut sit ut vero consectetur.";
const LONG_HEX_BYTES: &str = "446F6C6F72756D2064697374696E6374696F20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E";
struct Sample {
bytes: &'static str,
hex_bytes: &'static str,
}
const SAMPLES: &[Sample] = &[
Sample {
bytes: BYTES,
hex_bytes: HEX_BYTES,
},
Sample {
bytes: LONG_BYTES,
hex_bytes: LONG_HEX_BYTES,
},
];
const INVALID_SAMPLES: &[&str] = &[
"446F6C6F72756D2064697374696E6374696F20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722G",
"446F6C6F72756D2064697374696E6374696F20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E7365637465747572GE",
"446F6C6F72756D2064697374696E6374696G20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E",
"446F6C6F72756D2064697374696E637469GF20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E",
];
#[test]
fn test_hex_digit() {
const HEX_DIGITS_LOWER: &[char; 16] = &[
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
];
const HEX_DIGITS_UPPER: &[char; 16] = &[
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F',
];
for (i, digit) in HEX_DIGITS_LOWER.into_iter().enumerate() {
assert_eq!(hex_digit(*digit as u8), i as u8);
}
for (i, digit) in HEX_DIGITS_UPPER.into_iter().enumerate() {
assert_eq!(hex_digit(*digit as u8), i as u8);
}
}
#[test]
fn test_hex_digit_simd() {
const HEX_DIGITS: &[char; DIGIT_BATCH_SIZE] = &[
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0',
'1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E',
'F',
// '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
];
let mut set8 = [0u8; DIGIT_BATCH_SIZE];
for (c, b) in HEX_DIGITS.iter().zip(set8.iter_mut()) {
*b = *c as u8;
}
let mut sete = [0u8; DIGIT_BATCH_SIZE];
for i in 0..(DIGIT_BATCH_SIZE) {
sete[i] = (i as u8) % 16;
}
assert_eq!(
hex_digit_simd::<DIGIT_BATCH_SIZE>(Simd::from_array(set8)),
Simd::from_array(sete)
);
}
#[test]
fn test_hex_byte() {
const HEX_BYTES_VALID: &[([u8; 2], u8)] = &[
(['f' as u8, 'f' as u8], 0xff),
(['0' as u8, '0' as u8], 0x00),
(['1' as u8, '1' as u8], 0x11),
(['e' as u8, 'f' as u8], 0xef),
(['f' as u8, 'e' as u8], 0xfe),
(['0' as u8, 'f' as u8], 0x0f),
(['f' as u8, '0' as u8], 0xf0),
];
for (hb, b) in HEX_BYTES_VALID {
assert_eq!(hex_byte(hb[0], hb[1]), Some(*b));
assert_eq!(HexByteDecoderA::decode_unpacked(hb[0], hb[1]), *b as u16);
assert_eq!(HexByteDecoderB::decode_unpacked(hb[0], hb[1]), *b as u16);
assert_eq!(HexByteDecoderA::decode_packed(hb), *b as u16);
assert_eq!(HexByteDecoderB::decode_packed(hb), *b as u16);
}
const HEX_BYTES_INVALID: &[[u8; 2]] = &[
['f' as u8, 'g' as u8],
['0' as u8, 'g' as u8],
['1' as u8, 'g' as u8],
['e' as u8, 'g' as u8],
['f' as u8, 'g' as u8],
['0' as u8, 'g' as u8],
['f' as u8, 'g' as u8],
];
for hb in HEX_BYTES_INVALID {
assert_eq!(hex_byte(hb[0], hb[1]), None);
assert_ne!(
HexByteDecoderA::decode_unpacked(hb[0], hb[1]) & WIDE_INVALID_BIT,
0
);
assert_ne!(
HexByteDecoderB::decode_unpacked(hb[0], hb[1]) & WIDE_INVALID_BIT,
0
);
assert_ne!(HexByteDecoderA::decode_packed(hb) & WIDE_INVALID_BIT, 0);
assert_ne!(HexByteDecoderB::decode_packed(hb) & WIDE_INVALID_BIT, 0);
}
}
#[test]
fn test_hex_byte_simd() {
const HEX_BYTES_VALID: [[u8; 2]; WIDE_BATCH_SIZE] = [
*b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34", *b"ff", *b"00", *b"11",
*b"ef", *b"fe", *b"0f", *b"f0",
*b"34",
// *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
// *b"ff", *b"00", *b"11", *b"ef", *b"fe", *b"0f", *b"f0", *b"34",
];
const BYTES_VALID: [u8; WIDE_BATCH_SIZE] = [
0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34, 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f,
0xf0,
0x34,
// 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
// 0xff, 0x00, 0x11, 0xef, 0xfe, 0x0f, 0xf0, 0x34,
];
let hex_bytes = conv::u8x2_to_u8(HEX_BYTES_VALID);
let bytes = Simd::from_array(BYTES_VALID);
if_trace_simd! {
println!("hex_bytes: {HEX_BYTES_VALID:02x?}");
println!("hex_bytes: {hex_bytes:02x?}");
println!("bytes: {BYTES_VALID:02x?}");
println!("bytes: {bytes:04x?}");
}
assert_eq!(HexByteDecoderA::decode_simd(hex_bytes), Some(bytes));
assert_eq!(HexByteDecoderB::decode_simd(hex_bytes), Some(bytes));
/*const HEX_BYTES_INVALID: &[[u8; 2]] = &[
['f' as u8, 'g' as u8],
['0' as u8, 'g' as u8],
['1' as u8, 'g' as u8],
['e' as u8, 'g' as u8],
['f' as u8, 'g' as u8],
['0' as u8, 'g' as u8],
['f' as u8, 'g' as u8],
];
for hb in HEX_BYTES_INVALID {
assert_eq!(hex_byte(hb[0], hb[1]), None);
assert!(hex_byte_niched(hb[0], hb[1]) & WIDE_INVALID_BIT != 0);
}*/
}
macro_rules! test_f {
(boxed $f:ident) => {
test_f!(@ $f, Box::as_ref)
};
(@ $f:ident, $trans:expr) => {
for (i, Sample { bytes, hex_bytes }) in SAMPLES.into_iter().enumerate() {
let result = $f(hex_bytes.as_bytes());
assert_eq!(
result.as_ref().map($trans),
Some(bytes.as_bytes()),
"Sample {i} ({hex_bytes:?} => {bytes:?}) did not decode correctly (expected Some)"
);
}
for (i, hex_bytes) in INVALID_SAMPLES.into_iter().enumerate() {
let result = $f(hex_bytes.as_bytes());
assert_eq!(
result.as_ref().map($trans),
None,
"Sample {i} ({hex_bytes:?}) did not decode correctly (expected None)"
);
}
};
}
#[cfg(feature = "alloc")]
#[test]
fn test_dyn_iter_option() {
test_f!(boxed hex_bytes_dyn);
}
#[cfg(feature = "alloc")]
#[test]
fn test_dyn_unsafe() {
test_f!(boxed hex_bytes_dyn_unsafe);
}
#[cfg(feature = "alloc")]
#[test]
fn test_dyn_unsafe_iter() {
test_f!(boxed hex_bytes_dyn_unsafe_iter);
}
}
pub mod prelude;

12
src/prelude.rs Normal file
View File

@ -0,0 +1,12 @@
pub(crate) use crate::simd;
pub(crate) use crate::util;
pub(crate) use simd::arch;
pub(crate) use simd::if_trace_simd;
pub(crate) use simd::SimdBitwise;
pub(crate) use simd::SimdLoad;
pub(crate) use simd::SimdSplat;
pub(crate) use simd::SimdTestAnd;
pub(crate) use util::array_op;
pub(crate) use util::unroll;

View File

@ -1,11 +1,16 @@
use core::simd::{LaneCount, Simd, SupportedLaneCount, SimdElement};
use core::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount};
use crate::util::cast;
use crate::util;
use util::cast;
const W_128: usize = 128 / 8;
const W_256: usize = 256 / 8;
const W_512: usize = 512 / 8;
/// The value which cause `vpshufb` to write 0 instead of indexing.
const SHUF_0: u8 = 0b1000_0000;
#[cfg(target_arch = "aarch64")]
pub use core::arch::aarch64 as arch;
#[cfg(target_arch = "arm")]
@ -19,13 +24,42 @@ pub use core::arch::x86 as arch;
#[cfg(target_arch = "x86_64")]
pub use core::arch::x86_64 as arch;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub use arch::{__m128i, __m256i, __m512i};
// use the maximum batch size that would be supported by AVX-512
//pub(crate) const SIMD_WIDTH: usize = 512;
pub const SIMD_WIDTH: usize = 256;
/// The batch size used for the "wide" decoded hex bytes (any bit in the upper half indicates an error).
pub const WIDE_BATCH_SIZE: usize = SIMD_WIDTH / 16;
/// The batch size used for the hex digits.
pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2;
pub const GATHER_BATCH_SIZE: usize = DIGIT_BATCH_SIZE / 4;
#[macro_export]
macro_rules! __if_trace_simd {
($( $tt:tt )*) => {
// disabled
//{ $( $tt )* }
};
}
pub use __if_trace_simd as if_trace_simd;
pub trait IsSimd {
type Lane;
const LANES: usize;
}
impl<T, const LANES: usize> IsSimd for Simd<T, LANES> where LaneCount<LANES>: SupportedLaneCount, T: SimdElement {
impl<T, const LANES: usize> IsSimd for Simd<T, LANES>
where
LaneCount<LANES>: SupportedLaneCount,
T: SimdElement,
{
type Lane = T;
const LANES: usize = LANES;
@ -33,7 +67,7 @@ impl<T, const LANES: usize> IsSimd for Simd<T, LANES> where LaneCount<LANES>: Su
macro_rules! specialized {
($(
$vis:vis fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$vis:vis fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $args:tt )*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
$width:pat_param $( if $cfg:meta )? => $impl:expr
),+
@ -42,7 +76,7 @@ macro_rules! specialized {
)+) => {$(
#[allow(dead_code)]
#[inline(always)]
$vis fn $name<const $LANES: usize$(, $( $generics )+)?>($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
$vis fn $name<const $LANES: usize$(, $( $generics )+)?>($( $args )*) -> $rt $( where $( $where )* )? {
// abusing const generics to specialize without the unsoundness of real specialization!
match $LANES {
$(
@ -52,9 +86,9 @@ macro_rules! specialized {
}
}
)+};
($trait:ident for $ty:ty;
($LANES:ident =>
$(
fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
fn $name:ident$(<[$( $generics:tt )+]>)?($( $args:tt )*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
$width:pat_param $( if $cfg:meta )? => $impl:expr
),+
@ -62,9 +96,9 @@ macro_rules! specialized {
}
)+
) => {
impl<const $LANES: usize> $trait for $ty {$(
$(
#[inline(always)]
fn $name$(<$( $generics )+>)?($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
fn $name$(<$( $generics )+>)?($( $args )*) -> $rt $( where $( $where )* )? {
// abusing const generics to specialize without the unsoundness of real specialization!
match $LANES {
$(
@ -73,85 +107,192 @@ macro_rules! specialized {
),+
}
}
)+
};
($trait:ident for $ty:ty;
$(
fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $args:tt )*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
$width:pat_param $( if $cfg:meta )? => $impl:expr
),+
$(,)?
}
)+
) => {
impl<const $LANES: usize> $trait for $ty {$(
specialized! { LANES =>
fn $name$(<[$( $generics:tt )+]>)?($( $args )*) -> $rt $(where [$( $where )*])? {
$(
$width $( if $cfg )? => $impl
),+
}
}
)+}
};
}
macro_rules! set1 {
($arch:ident, $inst:ident, $vec:ident, $reg:ident, $n:ident) => {{
let out: core::arch::$arch::$vec;
core::arch::asm!(concat!(stringify!($inst), " {}, {}"), lateout($reg) out, in(xmm_reg) cast::<_, core::arch::$arch::__m128i>($n), options(pure, nomem, preserves_flags, nostack));
macro_rules! set1_short {
($inst:ident, $vec:ident, $reg:ident, $n:ident: $n_ty:ty) => {{
// WOW. this is 12% faster than broadcast on xmm. Seems not so much on ymm (more expensive
// array init?).
const O_LANES: usize = std::mem::size_of::<$vec>() / std::mem::size_of::<$n_ty>();
util::cast::<_, $vec>(Simd::<$n_ty, O_LANES>::from_array([$n; O_LANES]))
//let out: $vec;
//core::arch::asm!(concat!(stringify!($inst), " {}, {}"), lateout($reg) out, in(xmm_reg) cast::<_, __m128i>($n), options(pure, nomem, preserves_flags, nostack));
//out
}};
}
macro_rules! set1_long {
($inst:ident, $vec:ident, $reg:ident, $n:ident: $n_ty:ty) => {{
//const O_LANES: usize = std::mem::size_of::<$vec>() / std::mem::size_of::<$n_ty>();
//util::cast::<_, $vec>(Simd::<$n_ty, O_LANES>::from_array([$n; O_LANES]))
let out: $vec;
core::arch::asm!(concat!(stringify!($inst), " {}, {}"), lateout($reg) out, in(xmm_reg) cast::<_, __m128i>($n), options(pure, nomem, preserves_flags, nostack));
out
}};
}
pub trait DoubleWidth {
type Output;
}
macro_rules! impl_double_width {
($($in:ty => $out:ty),+) => {
$(
impl DoubleWidth for $in {
type Output = $out;
}
)+
}
}
impl_double_width!(u8 => u16, u16 => u32, u32 => u64);
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl_double_width!(u64 => arch::__m128i, arch::__m128i => arch::__m256i, arch::__m256i => arch::__m512i);
pub trait SimdSplat<const LANES: usize> {
type Output;
fn splat(self) -> Self::Output;
fn splat_zero() -> Self::Output;
}
pub trait SimdLoad: Sized {
fn load_128(self) -> arch::__m128i;
#[inline(always)]
fn load_256(self) -> arch::__m256i {
unsafe { util::cast(self.load_128()) }
}
#[inline(always)]
fn load_512(self) -> arch::__m512i {
unsafe { util::cast(self.load_128()) }
}
}
macro_rules! impl_load {
(($in:ident, $($in_fmt:ident)?) $in_v:ident) => {{
unsafe {
let out: _;
core::arch::asm!(concat!("vmovq {}, {:", $(stringify!($in_fmt), )? "}"), lateout(xmm_reg) out, in($in) $in_v, options(pure, nomem, preserves_flags, nostack));
out
}};
}
}}
}
macro_rules! simd_load_fallback {
($ty:ty, $vec_ty:ty, $self:ident) => {{
let mut a = std::mem::MaybeUninit::uninit_array();
a[0] = std::mem::MaybeUninit::new($self);
for i in 1..(std::mem::size_of::<$vec_ty>() / std::mem::size_of::<$ty>()) {
a[i] = std::mem::MaybeUninit::new(0);
}
let a = unsafe { std::mem::MaybeUninit::array_assume_init(a) };
Simd::from_array(a).into()
}};
}
macro_rules! impl_ops {
($(
$ty:ty {
/// The appropriate register type.
reg: $reg:ident,
reg_fmt: $($reg_fmt:ident)?,
broadcast: $broadcast:ident,
set1: [$set1_128:ident, $set1_256:ident, $set1_512:ident]
$(,)?
}
)+) => {$(
impl SimdLoad for $ty {
#[inline(always)]
fn load_128(self) -> arch::__m128i {
#[cfg(target_feature = "avx")]
{ impl_load!(($reg, $($reg_fmt)?) self) }
#[cfg(not(target_feature = "avx"))]
simd_load_fallback!($ty, arch::__m128i, self)
}
}
impl<const LANES: usize> SimdSplat<LANES> for $ty where LaneCount<LANES>: SupportedLaneCount {
type Output = Simd<$ty, LANES>;
specialized! { LANES =>
fn splat(self) -> Self::Output {
W_128 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx2") => unsafe { cast(set1_short!($broadcast, __m128i, xmm_reg, self: $ty)) },
W_256 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx2") => unsafe { cast(set1_long!($broadcast, __m256i, ymm_reg, self: $ty)) },
W_512 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx512bw") => unsafe { cast(set1_long!($broadcast, __m512i, zmm_reg, self: $ty)) },
// these are *terrible*. They compile to a bunch of MOVs and SETs
W_128 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(arch::$set1_128(self as i8)) },
W_256 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(arch::$set1_256(self as i8)) },
// I can't actually test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
W_512 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx512f") => unsafe { cast(arch::$set1_512(self as i8)) },
_ => Simd::splat(self),
}
fn splat_zero() -> Self::Output {
W_128 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "sse2") => unsafe { cast(arch::_mm_setzero_si128()) },
W_256 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx") => unsafe { cast(arch::_mm256_setzero_si256()) },
W_512 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx512f") => unsafe { cast(arch::_mm512_setzero_si512()) },
_ => Self::splat(0),
}
}
}
)+};
}
impl_ops! {
u8 {
reg: reg_byte,
reg_fmt: ,
broadcast: vpbroadcastb,
set1: [_mm_set1_epi8, _mm256_set1_epi8, _mm512_set1_epi8]
}
specialized! {
// TODO: special case https://www.felixcloutier.com/x86/vpbroadcastb:vpbroadcastw:vpbroadcastd:vpbroadcastq
pub fn splat_n<LANES>(n: u8) -> Simd<u8, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastb, __m128i, xmm_reg, n)) },
W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastb, __m128i, xmm_reg, n)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastb, __m256i, ymm_reg, n)) },
W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastb, __m256i, ymm_reg, n)) },
// these are *terrible*. They compile to a bunch of MOVs and SETs
W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) },
W_128 if all(target_arch = "x86", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm256_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm256_set1_epi8(n as i8)) },
// I can't really test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi8(n as i8)) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi8(n as i8)) },
_ => Simd::splat(n),
u16 {
reg: reg,
reg_fmt: e,
broadcast: vpbroadcastw,
set1: [_mm_set1_epi16, _mm256_set1_epi16, _mm512_set1_epi16]
}
// TODO: special case https://www.felixcloutier.com/x86/vpbroadcastb:vpbroadcastw:vpbroadcastd:vpbroadcastq
pub fn splat_u16<LANES>(n: u16) -> Simd<u16, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastw, __m128i, xmm_reg, n)) },
W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastw, __m128i, xmm_reg, n)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastw, __m256i, ymm_reg, n)) },
W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastw, __m256i, ymm_reg, n)) },
// these are *terrible*. They compile to a bunch of MOVs and SETs
W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi16(n as i16)) },
W_128 if all(target_arch = "x86", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm_set1_epi16(n as i16)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm256_set1_epi16(n as i16)) },
W_256 if all(target_arch = "x86", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm256_set1_epi16(n as i16)) },
// I can't really test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi16(n as i16)) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi16(n as i16)) },
_ => Simd::splat(n),
u32 {
reg: reg,
reg_fmt: e,
broadcast: vpbroadcastd,
set1: [_mm_set1_epi32, _mm256_set1_epi32, _mm512_set1_epi32]
}
pub fn splat_0u8<LANES>() -> Simd<u8, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
// these are fine, they are supposed to XOR themselves to zero out.
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
_ => unsafe { crate::util::cast(splat_n(0)) },
}
pub fn splat_0u16<LANES>() -> Simd<u16, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
// these are fine, they are supposed to XOR themselves to zero out.
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
_ => unsafe { crate::util::cast(splat_n(0)) },
u64 {
reg: reg,
reg_fmt: r,
broadcast: vpbroadcastq,
set1: [_mm_set1_epi64, _mm256_set1_epi64, _mm512_set1_epi64]
}
}
@ -217,15 +358,6 @@ macro_rules! __swizzle {
pub use __swizzle as swizzle;
pub use __swizzle_indices as swizzle_indices;
#[inline(always)]
pub fn load_u64_m128(v: u64) -> arch::__m128i {
unsafe {
let out: _;
core::arch::asm!("vmovq {}, {}", lateout(xmm_reg) out, in(reg) v, options(pure, nomem, preserves_flags, nostack));
out
}
}
/// Merges the low halves of `a` and `b` into a single register like `ab`.
#[inline(always)]
pub fn merge_lo_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
@ -268,7 +400,7 @@ macro_rules! extract_lohi_bytes {
#[inline(always)]
pub fn extract_lo_bytes(v: arch::__m256i) -> arch::__m128i {
extract_lohi_bytes!(([0xffu16; 8], vpand, vpackuswb), v)
extract_lohi_bytes!(([0x00ffu16; 8], vpand, vpackuswb), v)
}
#[inline(always)]
@ -296,13 +428,18 @@ pub trait SimdBitwise {
fn and(self, rhs: Self) -> Self;
}
pub trait SimdWiden {
/// Widens the lower bytes by spacing and zero-extending them to N bytes.
fn widen<const N: usize, const O: usize, const BE: bool>(self) -> Self;
}
#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m128i {
#[inline(always)]
fn test_and_non_zero(self, mask: Self) -> bool {
unsafe {
let out: u8;
core::arch::asm!("vptest {a}, {b}", "jz 2f", "mov {out}, 1", "jnz 3f", "2:", "mov {out}, 0", "3:", a = in(xmm_reg) self, b = in(xmm_reg) mask, out = out(reg_byte) out, options(pure, nomem, nostack));
core::arch::asm!("vptest {a}, {b}", "jz 2f", "mov {out}, 1", "jnz 3f", "2:", "mov {out}, 0", "3:", a = in(xmm_reg) self, b = in(xmm_reg) mask, out = lateout(reg_byte) out, options(pure, nomem, nostack));
core::mem::transmute(out)
}
}
@ -314,7 +451,7 @@ impl SimdTestAnd for arch::__m256i {
fn test_and_non_zero(self, mask: Self) -> bool {
unsafe {
let out: u8;
core::arch::asm!("vptest {a}, {b}", "jz 2f", "mov {out}, 1", "jnz 3f", "2:", "mov {out}, 0", "3:", a = in(ymm_reg) self, b = in(ymm_reg) mask, out = out(reg_byte) out, options(pure, nomem, nostack));
core::arch::asm!("vptest {a}, {b}", "jz 2f", "mov {out}, 1", "jnz 3f", "2:", "mov {out}, 0", "3:", a = in(ymm_reg) self, b = in(ymm_reg) mask, out = lateout(reg_byte) out, options(pure, nomem, nostack));
core::mem::transmute(out)
}
}
@ -331,7 +468,7 @@ impl SimdBitwise for arch::__m128i {
arch::_mm_or_si128(self, rhs)
} else {
let out: _;
core::arch::asm!("vpor {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = out(xmm_reg) out, options(pure, nomem, preserves_flags, nostack));
core::arch::asm!("vpor {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = lateout(xmm_reg) out, options(pure, nomem, preserves_flags, nostack));
out
}
}
@ -344,7 +481,7 @@ impl SimdBitwise for arch::__m128i {
arch::_mm_and_si128(self, rhs)
} else {
let out: _;
core::arch::asm!("vpand {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = out(xmm_reg) out, options(pure, nomem, preserves_flags, nostack));
core::arch::asm!("vpand {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = lateout(xmm_reg) out, options(pure, nomem, preserves_flags, nostack));
out
}
}
@ -360,7 +497,7 @@ impl SimdBitwise for arch::__m256i {
arch::_mm256_or_si256(self, rhs)
} else {
let out: _;
core::arch::asm!("vpor {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = out(ymm_reg) out, options(pure, nomem, preserves_flags, nostack));
core::arch::asm!("vpor {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = lateout(ymm_reg) out, options(pure, nomem, preserves_flags, nostack));
out
}
}
@ -373,9 +510,192 @@ impl SimdBitwise for arch::__m256i {
arch::_mm256_and_si256(self, rhs)
} else {
let out: _;
core::arch::asm!("vpand {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = out(ymm_reg) out, options(pure, nomem, preserves_flags, nostack));
core::arch::asm!("vpand {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = lateout(ymm_reg) out, options(pure, nomem, preserves_flags, nostack));
out
}
}
}
}
macro_rules! widen_128_impl {
($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:literal, $offset:literal) => {unsafe {
const LEN: usize = std::mem::size_of::<$Self>();
const MASK_MOD_EQ: usize = if $be { 0 } else { $shift };
// TODO: evaluate whether there is a more efficient approach for the ignored bytes.
const INDICES: [u8; LEN] = util::array_op!(gen[LEN] |i| {
if (i + 1) % ($shift + 1) == MASK_MOD_EQ {
((i as u8) >> $shift) + $offset
} else {
SHUF_0
}
});
const INDICES_SIMD: $Self = unsafe { util::cast(INDICES) };
//const MASK: [u8; LEN] = util::array_op!(gen[LEN] |i| {
// if (i + 1) % ($shift + 1) == MASK_MOD_EQ { 0xff } else { 0x00 }
//});
//const MASK: [u8; LEN] = util::array_op!(gen[LEN] |i| if ((i + 1) % ($shift + 1) == $offset) { 0xff } else { 0x00 });
let out: _;
core::arch::asm!("vpshufb {out}, {in}, {indices}", in = in($reg) $self, indices = in($reg) INDICES_SIMD, out = lateout($reg) out, options(pure, nomem, preserves_flags, nostack));
if_trace_simd! {
println!("Offset: {}", $offset);
println!("Indices: {:?}", INDICES);
//println!("MASK: {MASK:?}");
}
out
//SimdBitwise::and(out, util::cast(MASK))
}};
($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:ident, $offset_in:ident [$( $offset:literal ),+]) => {
match $offset_in {
$( $offset => if $be { widen_128_impl!($self, $Self, $reg, $shift, true, $offset) } else { widen_128_impl!($self, $Self, $reg, $shift, false, $offset) }, )+
_ => panic!("Unsupported widen O value"),
}
};
($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:ident, $offset:ident) => {
widen_128_impl!($self, $Self, $reg, $shift, $be, $offset [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
};
}
macro_rules! widen_256_impl {
($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:literal, $offset:literal) => {unsafe {
const LEN: usize = std::mem::size_of::<$Self>();
const MASK_MOD_EQ: usize = if $be { 0 } else { $shift };
// TODO: evaluate whether there is a more efficient approach for the ignored bytes.
const INDICES_LO: [u8; LEN] = util::array_op!(gen[LEN] |i| {
let i = i % LEN;
if (i + 1) % ($shift + 1) == MASK_MOD_EQ {
((i as u8) >> $shift) + $offset
} else {
SHUF_0
}
});
const INDICES_HI: [u8; LEN] = util::array_op!(gen[LEN] |i| {
let i = (i % (LEN / 2)) + LEN / 2;
if (i + 1) % ($shift + 1) == MASK_MOD_EQ {
((i as u8) >> $shift) + $offset
} else {
SHUF_0
}
});
const INDICES_LO_SIMD: $Self = unsafe { util::cast(INDICES_LO) };
const INDICES_HI_SIMD: $Self = unsafe { util::cast(INDICES_HI) };
const HI_MASK: [u8; LEN] = util::array_op!(gen[LEN] |i| if i < LEN / 2 { 0x00 } else { 0xff });
const HI_MASK_SIMD: $Self = unsafe { util::cast(HI_MASK) };
//const MASK: [u8; LEN] = util::array_op!(gen[LEN] |i| {
// if (i + 1) % ($shift + 1) == MASK_MOD_EQ { 0xff } else { 0x00 }
//});
//const MASK: [u8; LEN] = util::array_op!(gen[LEN] |i| if ((i + 1) % ($shift + 1) == $offset) { 0xff } else { 0x00 });
let lo: $Self;
core::arch::asm!("vpshufb {out:x}, {in:x}, {indices:x}", in = in($reg) $self, indices = in($reg) INDICES_LO_SIMD, out = lateout($reg) lo, options(pure, nomem, preserves_flags, nostack));
let hi: $Self;
core::arch::asm!("vpshufb {out:x}, {in:x}, {indices:x}", in = in($reg) $self, indices = in($reg) INDICES_HI_SIMD, out = lateout($reg) hi, options(pure, nomem, preserves_flags, nostack));
let out = lo.or(hi.and(HI_MASK_SIMD));
if_trace_simd! {
println!("Offset: {}", $offset);
println!("Indices: {:?},{:?}", INDICES_LO, INDICES_HI);
//println!("MASK: {MASK:?}");
}
out
//SimdBitwise::and(out, util::cast(MASK))
}};
($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:ident, $offset_in:ident [$( $offset:literal ),+]) => {
match $offset_in {
$( $offset => if $be { widen_256_impl!($self, $Self, $reg, $shift, true, $offset) } else { widen_256_impl!($self, $Self, $reg, $shift, false, $offset) }, )+
_ => panic!("Unsupported widen O value"),
}
};
($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:ident, $offset:ident) => {
widen_256_impl!($self, $Self, $reg, $shift, $be, $offset [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
};
}
#[doc(hidden)]
#[macro_export]
macro_rules! __simd__shr_64 {
($n:literal, ($in_reg:ident) $in:expr) => {{
let out: _;
std::arch::asm!(concat!("vpsrlq {dst}, {src}, ", $n), src = in($in_reg) $in, dst = lateout($in_reg) out);
out
}};
}
pub use __simd__shr_64 as shr_64;
/*impl SimdWiden for arch::__m128i {
#[inline(always)]
fn widen<const N: usize, const O: usize, const BE: bool>(self) -> Self {
match N {
1 => self,
2 => widen_128_impl!(self, arch::__m128i, xmm_reg, 1, BE, O),
4 => widen_128_impl!(self, arch::__m128i, xmm_reg, 2, BE, O),
8 => widen_128_impl!(self, arch::__m128i, xmm_reg, 3, BE, O),
_ => panic!("Unsupported widen N value"),
}
}
}
impl SimdWiden for arch::__m256i {
#[inline(always)]
fn widen<const N: usize, const O: usize, const BE: bool>(self) -> Self {
match N {
1 => self,
2 => widen_256_impl!(self, arch::__m256i, ymm_reg, 1, BE, O),
4 => widen_256_impl!(self, arch::__m256i, ymm_reg, 2, BE, O),
8 => widen_256_impl!(self, arch::__m256i, ymm_reg, 3, BE, O),
16 => widen_256_impl!(self, arch::__m256i, ymm_reg, 4, BE, O),
_ => panic!("Unsupported widen N value"),
}
}
}*/
#[cfg(test)]
mod test {
use super::*;
/*#[test]
fn test_widen() {
const SAMPLE_IN: [u8; 32] = [
1, 2, 3, 4, 5, 6, 7, 8,
1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16,
9, 10, 11, 12, 13, 14, 15, 16
];
const SAMPLE_OUT_LO_BE: [u8; 32] = [
0, 1, 0, 2, 0, 3, 0, 4,
0, 5, 0, 6, 0, 7, 0, 8,
0, 1, 0, 2, 0, 3, 0, 4,
0, 5, 0, 6, 0, 7, 0, 8
];
const SAMPLE_OUT_HI_BE: [u8; 32] = [
0, 9, 0, 10, 0, 11, 0, 12,
0, 13, 0, 14, 0, 15, 0, 16,
0, 9, 0, 10, 0, 11, 0, 12,
0, 13, 0, 14, 0, 15, 0, 16
];
let sample: arch::__m256i = Simd::from_array(SAMPLE_IN).into();
let widened = sample.widen::<2, 0, true>();
//assert_eq!(Simd::from(widened), Simd::from_array(SAMPLE_OUT_LO_BE));
let widened = sample.widen::<2, 16, true>();
assert_eq!(Simd::from(widened), Simd::from_array(SAMPLE_OUT_HI_BE));
const SAMPLE_OUT_LO_LE: [u8; 32] = [
1, 0, 2, 0, 3, 0, 4, 0,
5, 0, 6, 0, 7, 0, 8, 0,
1, 0, 2, 0, 3, 0, 4, 0,
5, 0, 6, 0, 7, 0, 8, 0
];
const SAMPLE_OUT_HI_LE: [u8; 32] = [
9, 0, 10, 0, 11, 0, 12, 0,
13, 0, 14, 0, 15, 0, 16, 0,
9, 0, 10, 0, 11, 0, 12, 0,
13, 0, 14, 0, 15, 0, 16, 0
];
let sample: arch::__m256i = Simd::from_array(SAMPLE_IN).into();
let widened = sample.widen::<2, 0, false>();
//assert_eq!(Simd::from(widened), Simd::from_array(SAMPLE_OUT_LO_LE));
let widened = sample.widen::<2, 16, false>();
assert_eq!(Simd::from(widened), Simd::from_array(SAMPLE_OUT_HI_LE));
}*/
}

57
src/test.rs Normal file
View File

@ -0,0 +1,57 @@
macro_rules! from_utf8 {
($bytes:ident) => {
unsafe { std::str::from_utf8_unchecked($bytes) }
};
}
pub const BYTES: &[u8; 16] = b"Donald J. Trump!";
pub const HEX_BYTES: &[u8; 32] = b"446F6E616C64204A2E205472756D7021";
pub const STR: &str = from_utf8!(BYTES);
pub const HEX_STR: &str = from_utf8!(HEX_BYTES);
pub const LONG_BYTES: &[u8; 297] = b"Dolorum distinctio ut earum quidem distinctio necessitatibus quam. Sit praesentium facere perspiciatis iure aut sunt et et. Adipisci enim rerum illum et officia nisi recusandae. Vitae doloribus ut quia ea unde consequuntur quae illum. Id eius harum est. Inventore ipsum ut sit ut vero consectetur.";
pub const LONG_HEX_BYTES: &[u8; 594] = b"446F6C6F72756D2064697374696E6374696F20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E";
pub const LONG_STR: &str = from_utf8!(LONG_BYTES);
pub const LONG_HEX_STR: &str = from_utf8!(LONG_HEX_BYTES);
pub struct Sample {
pub bytes: &'static str,
pub hex_bytes: &'static str,
}
pub const SAMPLES: &[Sample] = &[
Sample {
bytes: STR,
hex_bytes: HEX_STR,
},
Sample {
bytes: LONG_STR,
hex_bytes: LONG_HEX_STR,
},
];
pub const INVALID_SAMPLES: &[&str] = &[
"446F6C6F72756D2064697374696E6374696F20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722G",
"446F6C6F72756D2064697374696E6374696F20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E7365637465747572GE",
"446F6C6F72756D2064697374696E6374696G20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E",
"446F6C6F72756D2064697374696E637469GF20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E",
];
#[doc(hidden)]
#[macro_export]
macro_rules! __test__name {
($group:literal, $f:literal) => {
concat!("[", $group, "] - ", $f)
};
($group:expr, $f:literal) => {
std::boxed::Box::leak(format!(name!("{}", $f), $group).into_boxed_str())
};
($group:expr, $f:expr) => {
std::boxed::Box::leak(format!(name!("{}", "{}"), $group, $f).into_boxed_str())
};
($group:literal, $f:expr) => {
std::boxed::Box::leak(format!(name!($group, "{}"), $f).into_boxed_str())
};
}
pub use __test__name as name;

View File

@ -2,21 +2,39 @@
#[macro_export]
macro_rules! __array_op {
(gen[$len:expr] |$i:pat_param| $val:expr) => {{
let mut out = std::mem::MaybeUninit::uninit_array();
let mut out = std::mem::MaybeUninit::uninit_array::<$len>();
let mut i = 0;
while i < $len {
out[i] = std::mem::MaybeUninit::new(match i { $i => $val });
out[i] = std::mem::MaybeUninit::new(match i {
$i => $val,
});
i += 1;
}
unsafe { std::mem::MaybeUninit::array_assume_init(out) }
#[allow(unused_unsafe)]
unsafe {
std::mem::MaybeUninit::array_assume_init(out)
}
}};
(map[$len:expr, $src:expr] |$i:pat_param, $s:pat_param| $val:expr) => {{
$crate::util::array_op!(gen[$len] |i| match i { $i => match $src[i] { $s => $val } })
$crate::util::array_op!(
gen[$len]
| i
| match i {
$i => match $src[i] {
$s => $val,
},
}
)
}};
}
pub use __array_op as array_op;
#[inline]
pub const fn cast_u8_u32<const N: usize>(arr: [u8; N]) -> [u32; N] {
array_op!(map[N, arr] |_, v| v as u32)
}
#[doc(hidden)]
#[macro_export]
macro_rules! __defer_impl {
@ -101,6 +119,71 @@ pub const fn align_up_to<const N: usize>(n: usize) -> usize {
return (n + (N - 1)) >> shift << shift;
}
#[inline(always)]
pub unsafe fn alloc_aligned_box<T, const ALIGN: usize>() -> Box<T> {
let ptr = alloc_aligned(
std::mem::size_of::<T>(),
std::cmp::max(std::mem::align_of::<T>(), ALIGN),
);
Box::from_raw(ptr as *mut T)
}
#[inline(always)]
pub unsafe fn alloc_aligned_box_slice<T, const ALIGN: usize>(len: usize) -> Box<[T]> {
let size = std::cmp::max(std::mem::size_of::<T>(), std::mem::align_of::<T>());
let ptr = alloc_aligned(size * len, std::cmp::max(std::mem::align_of::<T>(), ALIGN));
Box::from_raw(std::ptr::slice_from_raw_parts_mut(ptr as *mut T, len))
}
#[inline(always)]
pub unsafe fn alloc_aligned(len: usize, align: usize) -> *mut u8 {
let layout = std::alloc::Layout::from_size_align_unchecked(len, align);
std::alloc::alloc(layout)
}
#[macro_export]
macro_rules! __util__unroll {
(let [$( $x:ident ),+] => |$y:pat_param| $expr:expr) => {
$crate::util::unroll!(let [$( $x: ($x) ),+] => |$y| $expr);
};
(let [$( $id:ident: ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => {
$crate::util::unroll!(@ [$] [$( $id: ($( $x ),+) ),+] => |$( $y ),+| $($expr)+)
};
(@ [$dollar:tt] [$( $id:ident: ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => {
macro_rules! __unrolled {
($dollar id:ident, $dollar ( $dollar z:tt ),+) => {
#[allow(unused_parens)]
let ($( $y ),+) = ($dollar ( $dollar z ),+);
let $dollar id = $($expr)+;
};
}
$( __unrolled!($id, $( $x ),+); )+
//$(
// let ($( $y ),+) = ($( $x ),+);
// $expr;
//)+
};
([$( ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => {
$crate::util::unroll!(@ [$] [$( ($( $x ),+) ),+] => |$( $y ),+| $($expr)+)
};
(@ [$dollar:tt] [$( ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => {
macro_rules! __unrolled {
($dollar ( $dollar z:tt ),+) => {
#[allow(unused_parens)]
let ($( $y ),+) = ($dollar ( $dollar z ),+);
$($expr)+;
};
}
$( __unrolled!($( $x ),+); )+
//$(
// let ($( $y ),+) = ($( $x ),+);
// $expr;
//)+
};
}
pub use __util__unroll as unroll;
#[cfg(test)]
mod test {
use super::*;
@ -113,4 +196,33 @@ mod test {
assert_eq!(align_down_to::<16>(15), 0);
assert_eq!(align_down_to::<16>(17), 16);
}
#[test]
pub fn test_array_op_gen() {
assert_eq!(array_op!(gen[4] | i | i), [0, 1, 2, 3]);
assert_eq!(
array_op!(gen[16] | i | i),
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
);
assert_eq!(array_op!(gen[4] | i | i as u8), [0u8, 1, 2, 3]);
assert_eq!(
array_op!(gen[16] | i | i as u8),
[0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
);
const A: [u8; 4] = array_op!(gen[4] | i | i as u8);
const B: [u8; 16] = array_op!(gen[16] | i | i as u8);
assert_eq!(A, [0u8, 1, 2, 3]);
assert_eq!(B, [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
const LEN: usize = std::mem::size_of::<std::arch::x86_64::__m256i>();
const INDICES: [u8; LEN] = array_op!(gen[LEN] | i | (i as u8) >> 1);
assert_eq!(
INDICES,
[
0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12,
13, 13, 14, 14, 15, 15
]
);
}
}