Implemented very fast encoder

This commit is contained in:
Michael Pfaff 2022-11-03 23:38:54 -04:00
parent 5a98826784
commit aba7d5148a
Signed by: michael
GPG Key ID: CF402C4A012AA9D4
10 changed files with 1074 additions and 228 deletions

View File

@ -7,6 +7,7 @@ edition = "2021"
default = ["std"]
alloc = []
std = ["alloc"]
test = []
[dependencies]
@ -15,5 +16,11 @@ criterion = { version = "0.4", features = [ "real_blackbox" ] }
rand = "0.8.5"
[[bench]]
name = "bench"
name = "dec"
harness = false
required-features = ["test"]
[[bench]]
name = "enc"
harness = false
required-features = ["test"]

View File

@ -5,7 +5,9 @@
use std::mem::MaybeUninit;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use fast_hex::*;
use fast_hex::dec::*;
use fast_hex::simd::{DIGIT_BATCH_SIZE, WIDE_BATCH_SIZE};
use fast_hex::test::name;
const ASCII_BYTES: &[u8; 16] = b"Donald J. Trump!";
const HEX_BYTES: &[u8; ASCII_BYTES.len() * 2] = b"446F6E616C64204A2E205472756D7021";
@ -133,22 +135,6 @@ pub fn bench_256(c: &mut Criterion) {
benchmark_sized::<{ ASCII_BYTES_LONG.len() }, false>("256", HEX_BYTES_LONG, c);
}
const fn __make_hex_chars() -> [u8; 16] {
let mut chars = [0u8; 16];
let mut i = 0u8;
while (i as usize) < chars.len() {
chars[i as usize] = if i < 10 {
'0' as u8 + i
} else {
'a' as u8 + i - 10
};
i += 1;
}
chars
}
const HEX_CHARS: [u8; 16] = __make_hex_chars();
trait SliceRandom {
type Item;
@ -227,7 +213,7 @@ pub fn bench_2k(c: &mut Criterion) {
unsafe { std::mem::MaybeUninit::uninit().assume_init() };
let mut rng = rand::thread_rng();
for b in hex_bytes.iter_mut() {
*b = MaybeUninit::new(*HEX_CHARS.choose(&mut rng).unwrap());
*b = MaybeUninit::new(*fast_hex::enc::HEX_CHARS_LOWER.choose(&mut rng).unwrap());
}
let hex_bytes: [u8; LEN2] = unsafe { std::mem::transmute(hex_bytes) };
let bytes = match hex_bytes_dyn(hex_bytes.as_ref()) {
@ -250,7 +236,7 @@ pub fn bench_512k(c: &mut Criterion) {
unsafe { std::mem::transmute(Box::<[u8; LEN2]>::new_uninit()) };
let mut rng = rand::thread_rng();
for b in hex_bytes.iter_mut() {
*b = MaybeUninit::new(*HEX_CHARS.choose(&mut rng).unwrap());
*b = MaybeUninit::new(*fast_hex::enc::HEX_CHARS_LOWER.choose(&mut rng).unwrap());
}
let hex_bytes: Box<[u8; LEN2]> = unsafe { std::mem::transmute(hex_bytes) };
let bytes = match hex_bytes_dyn(hex_bytes.as_ref()) {
@ -273,7 +259,7 @@ pub fn bench_1_6m(c: &mut Criterion) {
unsafe { std::mem::transmute(Box::<[u8; LEN2]>::new_uninit()) };
let mut rng = rand::thread_rng();
for b in hex_bytes.iter_mut() {
*b = MaybeUninit::new(*HEX_CHARS.choose(&mut rng).unwrap());
*b = MaybeUninit::new(*fast_hex::enc::HEX_CHARS_LOWER.choose(&mut rng).unwrap());
}
let hex_bytes: Box<[u8; LEN2]> = unsafe { std::mem::transmute(hex_bytes) };
let bytes = match hex_bytes_dyn(hex_bytes.as_ref()) {
@ -400,6 +386,12 @@ pub fn bench_nano_hex_byte(c: &mut Criterion) {
bench_decoder::<HexByteDecoderA>(c, stringify!(HexByteDecoderA));
}
fn verification() {
fast_hex::simd::if_trace_simd! {
panic!("Illegal benchmark state: SIMD tracing enabled");
}
}
criterion_group!(
decode_benches,
bench_16,
@ -410,4 +402,4 @@ criterion_group!(
);
criterion_group!(micro_benches, bench_micro_hex_digit, bench_micro_hex_byte);
criterion_group!(nano_benches, bench_nano_hex_digit, bench_nano_hex_byte);
criterion_main!(decode_benches, micro_benches, nano_benches);
criterion_main!(verification, decode_benches, micro_benches, nano_benches);

61
benches/enc.rs Normal file
View File

@ -0,0 +1,61 @@
#![feature(generic_const_exprs)]
#![feature(new_uninit)]
#![feature(portable_simd)]
use std::mem::MaybeUninit;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use fast_hex::enc::{Encode as _, Encoder};
use fast_hex::test::name;
type Enc = Encoder<true>;
const ASCII_BYTES: &[u8; 16] = b"Donald J. Trump!";
const HEX_BYTES: &[u8; ASCII_BYTES.len() * 2] = b"446F6E616C64204A2E205472756D7021";
const ASCII_BYTES_LONG: &[u8; 256] = b"Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!Donald J. Trump!";
const HEX_BYTES_LONG: &[u8; ASCII_BYTES_LONG.len() * 2] = b"446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021446F6E616C64204A2E205472756D7021";
fn benchmark_sized<const N: usize, const HEAP_ONLY: bool>(
name: &str,
bytes: &[u8; N],
c: &mut Criterion,
) where
[(); N * 2]:,
{
if !HEAP_ONLY {
c.bench_function(name!(name, "enc const"), |b| {
b.iter(|| Enc::enc_const(black_box(bytes)))
});
c.bench_function(name!(name, "enc sized"), |b| {
b.iter(|| Enc::enc_sized(black_box(bytes)))
});
}
c.bench_function(name!(name, "enc sized heap"), |b| {
b.iter(|| Enc::enc_sized_heap(black_box(bytes)))
});
benchmark(name, bytes, c);
}
fn benchmark(name: &str, bytes: &[u8], c: &mut Criterion) {
c.bench_function(name!(name, "enc slice"), |b| {
b.iter(|| Enc::enc_slice(black_box(bytes)))
});
}
pub fn bench_16(c: &mut Criterion) {
benchmark_sized::<{ ASCII_BYTES.len() }, false>("16", ASCII_BYTES, c);
}
pub fn bench_256(c: &mut Criterion) {
benchmark_sized::<{ ASCII_BYTES_LONG.len() }, false>("256", ASCII_BYTES_LONG, c);
}
fn verification() {
fast_hex::simd::if_trace_simd! {
panic!("Illegal benchmark state: SIMD tracing enabled");
}
}
criterion_group!(encode_benches, bench_16, bench_256,);
criterion_main!(verification, encode_benches);

View File

@ -6,29 +6,12 @@ use core::simd::*;
#[cfg(feature = "alloc")]
use alloc::{boxed::Box, vec::Vec};
use crate::{simd, util};
use crate::prelude::*;
use simd::SimdBitwise as _;
use simd::SimdTestAnd as _;
use simd::{if_trace_simd, SIMD_WIDTH};
use util::array_op;
/// The batch size used for the "wide" decoded hex bytes (any bit in the upper half indicates an error).
pub const WIDE_BATCH_SIZE: usize = SIMD_WIDTH / 16;
/// The batch size used for the hex digits.
pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2;
const GATHER_BATCH_SIZE: usize = DIGIT_BATCH_SIZE / 4;
use simd::{DIGIT_BATCH_SIZE, GATHER_BATCH_SIZE, SIMD_WIDTH, WIDE_BATCH_SIZE};
const VALIDATE: bool = true;
#[inline]
const fn cast_u8_u32<const N: usize>(arr: [u8; N]) -> [u32; N] {
array_op!(map[N, arr] |_, v| v as u32)
}
pub const INVALID_BIT: u8 = 0b1000_0000;
pub const WIDE_INVALID_BIT: u16 = 0b1000_1000_0000_0000;
@ -54,7 +37,7 @@ const ASCII_DIGITS: [u8; 256] = {
)
};
const __ASCII_DIGITS_SIMD: [u32; 256] = cast_u8_u32(ASCII_DIGITS);
const __ASCII_DIGITS_SIMD: [u32; 256] = util::cast_u8_u32(ASCII_DIGITS);
const ASCII_DIGITS_SIMD: *const i32 = &__ASCII_DIGITS_SIMD as *const u32 as *const i32;
@ -74,7 +57,7 @@ where
&ASCII_DIGITS,
Mask::splat(true),
ascii.cast(),
simd::splat_0u8::<LANES>(),
u8::splat_zero(),
)
}
}
@ -144,10 +127,7 @@ macro_rules! hex_digits_simd_inline {
println!("hi_los: {:x?}", *$ptr);
}
let a = *$ptr;
let b = *$ptr.add(1);
let c = *$ptr.add(2);
let d = *$ptr.add(3);
unroll!(let [a: (0), b: (1), c: (2), d: (3)] => |i| *$ptr.add(i));
if_trace_simd! {
let f = |x| __ASCII_DIGITS_SIMD[x as usize];
@ -161,50 +141,29 @@ macro_rules! hex_digits_simd_inline {
);
}
let a = Simd::from_array(a);
let b = Simd::from_array(b);
let c = Simd::from_array(c);
let d = Simd::from_array(d);
unroll!(let [a, b, c, d] => |x| Simd::from_array(x));
let a = a.cast::<u32>();
let b = b.cast::<u32>();
let c = c.cast::<u32>();
let d = d.cast::<u32>();
unroll!(let [a, b, c, d] => |x| x.cast::<u32>());
if_trace_simd! {
println!("{a:x?}, {b:x?}, {c:x?}, {d:x?}");
}
let a = a.into();
let b = b.into();
let c = c.into();
let d = d.into();
unroll!(let [a, b, c, d] => |x| x.into());
let a = simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, a, 4);
let b = simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, b, 4);
let c = simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, c, 4);
let d = simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, d, 4);
unroll!(let [a, b, c, d] => |x| simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, x, 4));
let a = Simd::<u32, GATHER_BATCH_SIZE>::from(a).cast::<u8>();
let b = Simd::<u32, GATHER_BATCH_SIZE>::from(b).cast::<u8>();
let c = Simd::<u32, GATHER_BATCH_SIZE>::from(c).cast::<u8>();
let d = Simd::<u32, GATHER_BATCH_SIZE>::from(d).cast::<u8>();
unroll!(let [a, b, c, d] => |x| Simd::<u32, GATHER_BATCH_SIZE>::from(x).cast::<u8>());
if_trace_simd! {
println!("{a:x?}, {b:x?}, {c:x?}, {d:x?}");
println!("a,b,c,d: {a:x?}, {b:x?}, {c:x?}, {d:x?}");
}
// load the 64-bit integers into registers
let a = simd::load_u64_m128(util::cast(a));
let b = simd::load_u64_m128(util::cast(b));
let c = simd::load_u64_m128(util::cast(c));
let d = simd::load_u64_m128(util::cast(d));
unroll!(let [a, b, c, d] => |x| util::cast::<_, u64>(x).load_128());
if_trace_simd! {
let a = Simd::<u8, 16>::from(a);
let b = Simd::<u8, 16>::from(b);
let c = Simd::<u8, 16>::from(c);
let d = Simd::<u8, 16>::from(d);
unroll!(let [a, b, c, d] => |x| Simd::<u8, 16>::from(x));
println!("a,b,c,d: {a:x?}, {b:x?}, {c:x?}, {d:x?}");
}
@ -214,8 +173,7 @@ macro_rules! hex_digits_simd_inline {
let cd = simd::merge_lo_hi_m128(c, d);
if_trace_simd! {
let ab = Simd::<u8, 16>::from(ab);
let cd = Simd::<u8, 16>::from(cd);
unroll!(let [ab, cd] => |x| Simd::<u8, 16>::from(x));
println!("ab,cd: {ab:x?}, {cd:x?}");
}
@ -237,20 +195,16 @@ macro_rules! merge_hex_digits_into_bytes_inline {
let lsb = simd::extract_hi_bytes($hex_digits);
let msb1: simd::arch::__m128i;
unsafe { std::arch::asm!("vpsllw {dst}, {src}, 4", src = in(xmm_reg) msb, dst = lateout(xmm_reg) msb1) };
unsafe { std::arch::asm!("vpsllq {dst}, {src}, 4", src = in(xmm_reg) msb, dst = lateout(xmm_reg) msb1) };
if_trace_simd! {
let msb1: Simd<u8, WIDE_BATCH_SIZE> = msb1.into();
println!("msb1: {msb1:x?}");
}
let msb2 = msb1.and(Simd::from_array([0xf0f0u16; WIDE_BATCH_SIZE / 2]).into());
let msb2 = msb1.and(0xf0u8.splat().into());
let b = msb2.or(lsb);
if_trace_simd! {
let msb: Simd<u8, WIDE_BATCH_SIZE> = msb.into();
let msb1: Simd<u8, WIDE_BATCH_SIZE> = msb1.into();
let msb2: Simd<u8, WIDE_BATCH_SIZE> = msb2.into();
let lsb: Simd<u8, WIDE_BATCH_SIZE> = lsb.into();
let b: Simd<u8, WIDE_BATCH_SIZE> = b.into();
unroll!(let [msb, msb1, msb2, lsb, b] => |x| Simd::<u8, WIDE_BATCH_SIZE>::from(x));
println!("| Packed | Msb | <<4 | & | Lsb | Bytes | |");
Simd::<u8, DIGIT_BATCH_SIZE>::from($hex_digits)
@ -291,7 +245,7 @@ impl HexByteSimdDecoder for HexByteDecoderA {
let hex_digits = unsafe { hex_digits_simd_inline!(hi_los) };
if hex_digits.test_and_non_zero(simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT).into()) {
if hex_digits.test_and_non_zero(INVALID_BIT.splat().into()) {
return None;
}
@ -302,30 +256,12 @@ impl HexByteSimdDecoder for HexByteDecoderA {
pub type HBD = HexByteDecoderA;
pub mod conv {
use core::simd::{LaneCount, Simd, SupportedLaneCount};
use crate::util;
#[inline(always)]
pub const fn u8_to_u16<const N_OUT: usize>(a: [u8; N_OUT * 2]) -> [u16; N_OUT] {
unsafe { util::cast(a) }
}
#[inline(always)]
pub const fn u8x2_to_u8<const N_IN: usize>(a: [[u8; 2]; N_IN]) -> [u8; N_IN * 2] {
unsafe { util::cast(a) }
}
#[inline(always)]
pub const fn simdu8_to_simdu16<const N_OUT: usize>(
a: Simd<u8, { N_OUT * 2 }>,
) -> Simd<u16, N_OUT>
where
LaneCount<{ N_OUT * 2 }>: SupportedLaneCount,
LaneCount<N_OUT>: SupportedLaneCount,
{
unsafe { util::cast(a) }
}
}
macro_rules! decode_hex_bytes_non_vectored {
@ -338,10 +274,8 @@ macro_rules! decode_hex_bytes_non_vectored {
unsafe { *$bytes.get_unchecked_mut($i >> 1) = MaybeUninit::new(b as u8) };*/
let [hi, lo] = unsafe { *($ascii.as_ptr().add($i) as *const [u8; 2]) };
let lo = hex_digit(lo);
let hi = hex_digit(hi);
bad |= lo;
bad |= hi;
unroll!(let [lo, hi] => |x| hex_digit(x));
unroll!([(lo), (hi)] => |x| bad |= x);
/*if (hi & INVALID_BIT) | (lo & INVALID_BIT) != 0 {
println!("bad hex byte at {} ({}{})", $i, $ascii[$i] as char, $ascii[$i + 1] as char);
}*/
@ -381,9 +315,7 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
const VECTORED: bool = true;
if VECTORED {
use simd::arch;
let mut bad: arch::__m256i = simd::splat_0u8().into();
let mut bad: arch::__m256i = u8::splat_zero().into();
let mut i = 0;
while i < util::align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) {
let hex_digits = unsafe {
@ -408,7 +340,7 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
}
decode_hex_bytes_non_vectored!(i, ascii, bytes);
!bad.test_and_non_zero(simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT).into())
!bad.test_and_non_zero(INVALID_BIT.splat().into())
} else {
let mut i = 0;
decode_hex_bytes_non_vectored!(i, ascii, bytes);

View File

@ -1,15 +1,28 @@
//! SIMD-accelerated hex encoding.
use std::mem::MaybeUninit;
use std::simd::*;
use crate::{simd, util};
use crate::prelude::*;
use util::array_op;
use simd::{DIGIT_BATCH_SIZE, GATHER_BATCH_SIZE, SIMD_WIDTH, WIDE_BATCH_SIZE};
const HEX_CHARS_LOWER: [u8; 16] = array_op!(map[16, ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f']] |_, c| c as u8);
const HEX_CHARS_UPPER: [u8; 16] =
const REQUIRED_ALIGNMENT: usize = 64;
pub const HEX_CHARS_LOWER: [u8; 16] = array_op!(map[16, ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f']] |_, c| c as u8);
pub const HEX_CHARS_UPPER: [u8; 16] =
array_op!(map[16, HEX_CHARS_LOWER] |_, c| (c as char).to_ascii_uppercase() as u8);
const __HEX_CHARS_LOWER_SIMD: [u32; 16] = util::cast_u8_u32(HEX_CHARS_LOWER);
const HEX_CHARS_LOWER_SIMD: *const i32 = &__HEX_CHARS_LOWER_SIMD as *const u32 as *const i32;
const __HEX_CHARS_UPPER_SIMD: [u32; 16] = util::cast_u8_u32(HEX_CHARS_UPPER);
const HEX_CHARS_UPPER_SIMD: *const i32 = &__HEX_CHARS_UPPER_SIMD as *const u32 as *const i32;
// TODO: add a check for endianness (current is assumed LE)
const HEX_BYTES_LOWER: [u16; 256] = array_op!(gen[256] |i| ((HEX_CHARS_LOWER[(i & 0xf0) >> 4] as u16)) | ((HEX_CHARS_LOWER[i & 0x0f] as u16) << 8));
const HEX_BYTES_UPPER: [u16; 256] = array_op!(gen[256] |i| ((HEX_CHARS_UPPER[(i & 0xf0) >> 4] as u16)) | ((HEX_CHARS_UPPER[i & 0x0f] as u16) << 8));
macro_rules! select {
($cond:ident ? $true:ident : $false:ident) => {
if $cond {
@ -27,23 +40,350 @@ macro_rules! select {
};
}
macro_rules! const_impl {
#[inline(always)]
fn nbl_to_ascii<const UPPER: bool>(nbl: u8) -> u8 {
// fourth bit set if true
let at_least_10 = {
let b1 = nbl & 0b1010;
let b2 = nbl & 0b1100;
((nbl >> 1) | (b2 & (b2 << 1)) | (b1 & (b1 << 2))) & 0b1000
};
// 6th bit is always 1 with a-z and 0-9
let b6_val = if UPPER { (at_least_10 ^ 0b1000) << 2 } else { 0b100000 };
// 5th bit is always 1 with 0-9
let b5_val = (at_least_10 ^ 0b1000) << 1;
// 7th bit is always 1 with a-z and A-Z
let b7_val = at_least_10 << 3;
// fill all bits with the value of the 4th bit
let is_at_least_10_all_mask = (((at_least_10 << 4) as i8) >> 7) as u8;
// sub 9 if we're >=10
// a-z and A-Z start at ..0001 rather than ..0000 like 0-9, so we sub 9, not 10
let sub = 9 & is_at_least_10_all_mask;
// apply the sub, then OR in the constants
(nbl - sub) | b6_val | b5_val | b7_val
}
#[inline(always)]
fn nbl_wide_to_ascii<const UPPER: bool>(nbl: u16) -> u16 {
// fourth bit set if true
let at_least_10 = {
let b1 = nbl & 0b1010;
let b2 = nbl & 0b1100;
((nbl >> 1) | (b2 & (b2 << 1)) | (b1 & (b1 << 2))) & 0b1000
};
// mask used don the 6th bit.
let b6_val = if UPPER { (at_least_10 ^ 0b1000) << 2 } else { 0b100000 };
let b5_val = (at_least_10 ^ 0b1000) << 1;
let b7_val = at_least_10 << 3;
// sign extend the 1 if set
let is_at_least_10_all_mask = (((at_least_10 << 12) as i16) >> 15) as u16;
let sub = 9 & is_at_least_10_all_mask;
let c = (nbl - sub) | b6_val | b5_val | b7_val;
c
}
// the way this is used, is by inserting the u16 directly into a byte array, so on a little-endian system (assumed in the code), we need the low byte shifted to the left, which seems counterintuitive.
#[inline(always)]
fn byte_to_ascii<const UPPER: bool>(byte: u8) -> u16 {
//let byte = byte as u16;
//nbl_wide_to_ascii::<UPPER>((byte & 0xf0) >> 4) | (nbl_wide_to_ascii::<UPPER>(byte & 0x0f) << 8)
(nbl_to_ascii::<UPPER>((byte & 0xf0) >> 4) as u16) | ((nbl_to_ascii::<UPPER>(byte & 0x0f) as u16) << 8)
}
macro_rules! const_impl1 {
($UPPER:ident, $src:ident, $dst:ident) => {{
let mut i = 0;
const UNROLL: usize = 8;
let ub = $dst.len();
let aub = util::align_down_to::<{ 2 * UNROLL }>(ub);
let mut src = $src.as_ptr();
let mut dst = $dst.as_mut_ptr() as *mut u8;
while i < aub {
unsafe {
let [b1, b2, b3, b4, b5, b6, b7, b8] = [(); UNROLL];
unroll!(let [b1, b2, b3, b4, b5, b6, b7, b8] => |_| {
let b = *src;
src = src.add(1);
b
});
unroll!([(0, b1), (2, b2), (4, b3), (6, b4), (8, b5), (10, b6), (12, b7), (14, b8)] => |i, b| {
*dst.add(i) = *select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER).get_unchecked((b >> 4) as usize)
});
unroll!([(0, b1), (2, b2), (4, b3), (6, b4), (8, b5), (10, b6), (12, b7), (14, b8)] => |i, b| {
*dst.add(i + 1) = *select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER).get_unchecked((b & 0x0f) as usize)
});
dst = dst.add(2 * UNROLL);
i += 2 * UNROLL;
}
}
while i < ub {
let b = $src[i >> 1];
$dst[i] = MaybeUninit::new(select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER)[(b >> 4) as usize]);
$dst[i + 1] = MaybeUninit::new(select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER)[(b & 0x0f) as usize]);
i += 2;
unsafe {
let b = *src;
*dst = *select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER).get_unchecked((b >> 4) as usize);
dst = dst.add(1);
*dst = *select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER).get_unchecked((b & 0x0f) as usize);
dst = dst.add(1);
i += 2;
src = src.add(1);
}
}
}};
}
#[inline(always)]
fn u64_to_ne_u16(v: u64) -> [u16; 4] {
unsafe { std::mem::transmute(v.to_ne_bytes()) }
}
macro_rules! const_impl {
($UPPER:ident, $src:ident, $dst:ident) => {{
let mut i = 0;
const UNROLL: usize = 8;
let ub = $src.len();
let aub = util::align_down_to::<{ UNROLL }>(ub);
let mut src = $src.as_ptr() as *const u64;
//let mut dst = $dst.as_mut_ptr() as *mut u16;
let mut dst = $dst.as_mut_ptr() as *mut u16;
while i < aub {
unsafe {
let [b1, b2, b3, b4, b5, b6, b7, b8] = (src.read_unaligned()).to_ne_bytes();
src = src.add(1);
//let [b1, b2, b3, b4, b5, b6, b7, b8] = [(); UNROLL];
//unroll!(let [b1, b2, b3, b4, b5, b6, b7, b8] => |_| {
// let b = *src;
// src = src.add(1);
// b
//});
//let mut buf1: u64 = 0;
//let mut buf2: u64 = 0;
unroll!(let [b1, b2, b3, b4, b5, b6, b7, b8] => |b| {
*select!($UPPER ? HEX_BYTES_UPPER : HEX_BYTES_LOWER).get_unchecked(b as usize)
});
//unroll!(let [b1: (0, b1), b2: (1, b2), b3: (2, b3), b4: (3, b4), b5: (4, b5), b6: (5, b6), b7: (6, b7), b8: (7, b8)] => |i, v| {
// if i < 4 {
// (v as u64) << (i * 16)
// } else {
// (v as u64) << ((i - 4) * 16)
// }
//});
unroll!([(0, b1), (1, b2), (2, b3), (3, b4), (4, b5), (5, b6), (6, b7), (7, b8)] => |_, v| {
//*dst = *select!($UPPER ? HEX_BYTES_UPPER : HEX_BYTES_LOWER).get_unchecked(b as usize);
*dst = v;
//if i < 4 {
// //println!("[{i}] {v:064b}");
// buf1 |= v;
//} else {
// //println!("[{i}] {v:064b}");
// buf2 |= v;
//}
// if i < 4 {
// buf1[i] = MaybeUninit::new(v);
// } else {
// buf2[i - 4] = MaybeUninit::new(v);
// }
//*dst = byte_to_ascii::<$UPPER>(b);
dst = dst.add(1);
});
// TODO: would using vector store actually be faster here (particularly for the
// heap variant)
//assert!(dst < ($dst.as_mut_ptr() as *mut u64).add($dst.len()));
//*dst = buf1;
//dst = dst.add(1);
//assert!(dst < ($dst.as_mut_ptr() as *mut u64).add($dst.len()));
//*dst = buf2;
//dst = dst.add(1);
i += UNROLL;
}
}
let mut src = src as *const u8;
let mut dst = dst as *mut u16;
while i < ub {
unsafe {
let b = *src;
*dst = *select!($UPPER ? HEX_BYTES_UPPER : HEX_BYTES_LOWER).get_unchecked(b as usize);
//*dst = byte_to_ascii::<$UPPER>(b);
dst = dst.add(1);
src = src.add(1);
i += 1;
}
}
}};
}
/// The `$dst` must be 32-byte aligned.
macro_rules! common_impl {
($UPPER:ident, $src:ident, $dst:ident) => {
const_impl!($UPPER, $src, $dst)
};
(@disabled $UPPER:ident, $src:ident, $dst:ident) => {{
let mut i = 0;
let ub = $dst.len();
let aub = util::align_down_to::<DIGIT_BATCH_SIZE>(ub);
let mut src = $src.as_ptr();
let mut dst = $dst.as_mut_ptr();
while i < aub {
unsafe {
//let hi_los = $src.as_ptr().add(i) as *const [u8; GATHER_BATCH_SIZE];
//let chunk = $src.as_ptr().add(i >> 1) as *const [u8; WIDE_BATCH_SIZE];
//let chunk = *chunk;
//let chunk: simd::arch::__m128i = Simd::from_array(chunk).into();
let chunk: simd::arch::__m128i;
std::arch::asm!("vmovdqu {dst}, [{src}]", src = in(reg) src, dst = lateout(xmm_reg) chunk);
let hi = chunk.and(0xf0u8.splat().into());
let hi: simd::arch::__m128i = simd::shr_64!(4, (xmm_reg) hi);
let lo = chunk.and(0x0fu8.splat().into());
unroll!(let [hi, lo] => |x| Simd::<u8, WIDE_BATCH_SIZE>::from(x));
if_trace_simd! {
println!("hi,lo: {hi:02x?}, {lo:02x?}");
}
// TODO: find a more efficient approach
let hi = hi.cast::<u32>();
let lo = lo.cast::<u32>();
// just trunc these
let a: simd::arch::__m256i = util::cast(hi);
let c: simd::arch::__m256i = util::cast(lo);
// need to shift these over
unroll!(let [hi, lo] => |x| (&x as *const _ as *const [u32; 8]).add(1));
let b: simd::arch::__m256i = Simd::from_array(*hi).into();
let d: simd::arch::__m256i = Simd::from_array(*lo).into();
//unroll!(let [hi, lo] => |x| util::cast::<_, simd::arch::__m256i>(x));
//if_trace_simd! {
// unroll!(let [hi, lo] => |x| Simd::<u8, DIGIT_BATCH_SIZE>::from(x));
// println!("hi,lo: {hi:02x?}, {lo:02x?}");
//}
//let a = hi;
//let b: simd::arch::__m128i;
//std::arch::asm!("vpermq {:x}, {:y}, 1", lateout(xmm_reg) b, in(xmm_reg) hi);
//let b: simd::arch::__m128i;
//std::arch::asm!("vextracti128 {:x}, {:y}, 1", lateout(xmm_reg) b, in(xmm_reg) hi);
//let c = lo;
//let d: simd::arch::__m128i;
//std::arch::asm!("vextracti128 {:x}, {:y}, 1", lateout(xmm_reg) d, in(xmm_reg) lo);
//unroll!(let [a, b, c, d] => |x| {
// let o: simd::arch::__m256i;
// std::arch::asm!("vpmovzxdq {}, {}", lateout(ymm_reg) o, in(xmm_reg) x);
// o
//});
//let a = hi.widen::<4, 0, false>();
//let b = hi.widen::<4, 8, false>();
//let c = lo.widen::<4, 0, false>();
//let d = lo.widen::<4, 8, false>();
if_trace_simd! {
unroll!(let [a, b, c, d] => |x| Simd::<u32, GATHER_BATCH_SIZE>::from(x));
println!("a,b,c,d: {a:02x?}, {b:02x?}, {c:02x?}, {d:02x?}");
}
unroll!(let [a, b, c, d] => |x| simd::arch::__m256i::from(x));
// let indices: simd::arch::__m256i = Simd::from_array([]).into();
// std::arch::asm!("vpshufb {out:y}, {in:y}, {indices}", out = lateout(ymm_reg) a, in = in(ymm_reg) hi, indices = in(ymm_reg) );
unroll!(let [a, b, c, d] => |x| simd::arch::_mm256_i32gather_epi32(select!($UPPER ? HEX_CHARS_UPPER_SIMD : HEX_CHARS_LOWER_SIMD), x, 4));
unroll!(let [a, b, c, d] => |x| Simd::<u32, GATHER_BATCH_SIZE>::from(x).cast::<u8>());
if_trace_simd! {
println!("a,b,c,d: {a:02x?}, {b:02x?}, {c:02x?}, {d:02x?}");
}
// load the 64-bit integers into registers
unroll!(let [a, b, c, d] => |x| util::cast::<_, u64>(x).load_128());
if_trace_simd! {
unroll!(let [a, b, c, d] => |x| Simd::<u8, 16>::from(x));
println!("a,b,c,d: {a:02x?}, {b:02x?}, {c:02x?}, {d:02x?}");
}
// copy the second 64-bit integer into the upper half of xmm0 (lower half is the first 64-bit integer)
let ab = simd::merge_lo_hi_m128(a, b);
// copy the fourth 64-bit integer into the upper half of xmm2 (lower half is the third 64-bit integer)
let cd = simd::merge_lo_hi_m128(c, d);
if_trace_simd! {
unroll!(let [ab, cd] => |x| Simd::<u8, 16>::from(x));
println!("ab,cd: {ab:x?}, {cd:x?}");
}
let ab1: simd::arch::__m256i;
let cd1: simd::arch::__m256i;
std::arch::asm!("vpunpcklbw {:y}, {:y}, {:y}", lateout(ymm_reg) ab1, in(ymm_reg) ab, in(ymm_reg) cd);
std::arch::asm!("vpunpckhbw {:y}, {:y}, {:y}", lateout(ymm_reg) cd1, in(ymm_reg) ab, in(ymm_reg) cd);
//let abcd = simd::merge_m128_m256(util::cast(ab1), util::cast(cd1));
//let abcd = ab1;
let abcd: simd::arch::__m256i;
core::arch::asm!("vinserti128 {:y}, {:y}, {:x}, 0x1", lateout(ymm_reg) abcd, in(ymm_reg) ab1, in(ymm_reg) cd1, options(pure, nomem, preserves_flags, nostack));
//core::arch::asm!("vinserti128 {:y}, {:y}, {:x}, 0x1", lateout(ymm_reg) abcd, in(ymm_reg) ab1, in(ymm_reg) cd1);
// merge the xmm0 and xmm1 (ymm1) registers into ymm0
//let abcd = simd::merge_m128_m256(ab, cd);
if_trace_simd! {
let abcd: Simd<u8, DIGIT_BATCH_SIZE> = abcd.into();
println!("abcd: {abcd:x?}");
}
// HA! there's an undocumented requirement for the dest to be 32-byte aligned.
//assert_eq!((ptr.cast::<u8>() as usize) & 32 - 1, 0);
core::arch::asm!("vmovdqa [{}], {}", in(reg) dst as *mut i8, in(ymm_reg) abcd, options(preserves_flags, nostack));
dst = dst.add(DIGIT_BATCH_SIZE);
i += DIGIT_BATCH_SIZE;
src = src.add(WIDE_BATCH_SIZE);
}
}
while i < ub {
unsafe {
let b = *src;
*dst = MaybeUninit::new(*select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER).get_unchecked((b >> 4) as usize));
dst = dst.add(1);
*dst = MaybeUninit::new(*select!($UPPER ? HEX_CHARS_UPPER : HEX_CHARS_LOWER).get_unchecked((b & 0x0f) as usize));
dst = dst.add(1);
i += 2;
src = src.add(1);
}
}
if_trace_simd! {
let slice: &[_] = $dst.as_ref();
match std::str::from_utf8(unsafe { &*(slice as *const [_] as *const [u8]) }) {
Ok(s) => {
println!("encoded: {s:?}");
}
Err(e) => {
println!("encoded corrupted utf8: {e}");
}
}
}
}};
}
macro_rules! define_encode_str {
($name:ident$(<$N:ident>)?($in:ty) $(where $( $where:tt )+)?) => {
fn $name$(<const $N: usize>)?(src: $in) -> String $( where $( $where )+ )?;
};
}
macro_rules! impl_encode_str {
($name:ident$(<$N:ident>)?($in:ty) => $impl:ident (|$bytes:ident| $into_vec:expr) $(where $( $where:tt )+)?) => {
#[inline]
fn $name$(<const $N: usize>)?(src: $in) -> String $( where $( $where )+ )? {
let $bytes = Self::$impl(src);
unsafe { String::from_utf8_unchecked($into_vec) }
}
};
}
pub trait Encode {
@ -59,19 +399,29 @@ pub trait Encode {
/// Encodes the unsized input on the heap.
fn enc_slice(src: &[u8]) -> Box<[u8]>;
define_encode_str!(enc_str_sized<N>(&[u8; N]) where [u8; N * 2]:);
define_encode_str!(enc_str_sized_heap<N>(&[u8; N]) where [u8; N * 2]:);
define_encode_str!(enc_str_slice(&[u8]));
}
pub struct Encoder<const UPPER: bool = false>;
#[repr(align(32))]
struct Aligned32<T>(T);
impl<const UPPER: bool> Encode for Encoder<UPPER> {
#[inline]
fn enc_sized<const N: usize>(src: &[u8; N]) -> [u8; N * 2]
where
[u8; N * 2]:,
{
let mut buf = MaybeUninit::uninit_array();
common_impl!(UPPER, src, buf);
unsafe { MaybeUninit::array_assume_init(buf) }
// SAFETY: `Aligned32` has no initialization in and of itself, nor does an array of `MaybeUninit`
let mut buf =
unsafe { MaybeUninit::<Aligned32<[MaybeUninit<_>; N * 2]>>::uninit().assume_init() };
let buf1 = &mut buf.0;
common_impl!(UPPER, src, buf1);
unsafe { MaybeUninit::array_assume_init(buf.0) }
}
#[inline]
@ -79,22 +429,30 @@ impl<const UPPER: bool> Encode for Encoder<UPPER> {
where
[u8; N * 2]:,
{
let mut buf: Box<[MaybeUninit<u8>; N * 2]> = unsafe { Box::new_uninit().assume_init() };
let mut buf: Box<[MaybeUninit<u8>; N * 2]> =
unsafe { util::alloc_aligned_box::<_, REQUIRED_ALIGNMENT>() };
common_impl!(UPPER, src, buf);
unsafe { Box::from_raw(Box::into_raw(buf).cast()) }
}
#[inline]
fn enc_slice(src: &[u8]) -> Box<[u8]> {
let mut buf = Box::new_uninit_slice(src.len() * 2);
let mut buf: Box<[MaybeUninit<u8>]> =
unsafe { util::alloc_aligned_box_slice::<_, REQUIRED_ALIGNMENT>(src.len() * 2) };
common_impl!(UPPER, src, buf);
unsafe { Box::<[_]>::assume_init(buf) }
}
impl_encode_str!(enc_str_sized<N>(&[u8; N]) => enc_sized (|bytes| bytes.into()) where [u8; N * 2]:);
impl_encode_str!(enc_str_sized_heap<N>(&[u8; N]) => enc_sized_heap (|bytes| {
Vec::from_raw_parts(Box::into_raw(bytes) as *mut u8, N * 2, N * 2)
}) where [u8; N * 2]:);
impl_encode_str!(enc_str_slice(&[u8]) => enc_slice (|bytes| Vec::from(bytes)));
}
impl<const UPPER: bool> Encoder<UPPER> {
#[inline]
pub fn enc_const<const N: usize>(src: &[u8; N]) -> [u8; N * 2]
pub fn enc_const<const N: usize>(mut src: &[u8; N]) -> [u8; N * 2]
where
[u8; N * 2]:,
{
@ -110,14 +468,54 @@ mod test {
use crate::test::*;
#[test]
fn test_nbl_to_ascii() {
for i in 0..16 {
let a = nbl_to_ascii::<false>(i);
let b = HEX_CHARS_LOWER[i as usize];
assert_eq!(a, b, "({i}) {a:08b} != {b:08b}");
let a = nbl_to_ascii::<true>(i);
let b = HEX_CHARS_UPPER[i as usize];
assert_eq!(a, b, "({i}) {a:08b} != {b:08b}");
}
}
#[test]
fn test_nbl_wide_to_ascii() {
for i in 0..16 {
let a = nbl_wide_to_ascii::<false>(i);
let b = HEX_CHARS_LOWER[i as usize] as u16;
assert_eq!(a, b, "({i}) {a:08b} != {b:08b}");
let a = nbl_wide_to_ascii::<true>(i);
let b = HEX_CHARS_UPPER[i as usize] as u16;
assert_eq!(a, b, "({i}) {a:08b} != {b:08b}");
}
}
#[test]
fn test_byte_to_ascii() {
for i in 0..=255 {
let a = byte_to_ascii::<false>(i);
let b = HEX_BYTES_LOWER[i as usize];
assert_eq!(a, b, "({i}) {a:016b} != {b:016b}");
let a = byte_to_ascii::<true>(i);
let b = HEX_BYTES_UPPER[i as usize];
assert_eq!(a, b, "({i}) {a:016b} != {b:016b}");
}
}
macro_rules! for_each_sample {
($name:ident, |$sb:ident, $shb:ident| $expr:expr) => {
($name:ident, |$ss:pat_param, $shs:pat_param, $sb:pat_param, $shb:pat_param| $expr:expr) => {
#[test]
fn $name() {
let $ss = STR;
let $shs = HEX_STR;
let $sb = BYTES;
let $shb = HEX_BYTES;
$expr;
let $ss = LONG_STR;
let $shs = LONG_HEX_STR;
let $sb = LONG_BYTES;
let $shb = LONG_HEX_BYTES;
$expr;
@ -125,10 +523,26 @@ mod test {
};
}
type Enc = Encoder::<true>;
type Enc = Encoder<true>;
for_each_sample!(enc_const, |b, hb| assert_eq!(Enc::enc_const(b), *hb));
for_each_sample!(enc_sized, |b, hb| assert_eq!(Enc::enc_sized(b), *hb));
for_each_sample!(enc_sized_heap, |b, hb| assert_eq!(Enc::enc_sized_heap(b), Box::new(*hb)));
for_each_sample!(enc_slice, |b, hb| assert_eq!(Enc::enc_slice(b), (*hb).into_iter().collect::<Vec<_>>().into_boxed_slice()));
for_each_sample!(enc_const, |_, _, b, hb| assert_eq!(Enc::enc_const(b), *hb));
for_each_sample!(enc_sized, |_, _, b, hb| assert_eq!(Enc::enc_sized(b), *hb));
for_each_sample!(enc_sized_heap, |_, _, b, hb| assert_eq!(
Enc::enc_sized_heap(b),
Box::new(*hb)
));
for_each_sample!(enc_slice, |_, _, b, hb| assert_eq!(
Enc::enc_slice(b),
(*hb).into_iter().collect::<Vec<_>>().into_boxed_slice()
));
for_each_sample!(enc_str_sized, |_, hs, b, _| assert_eq!(Enc::enc_str_sized(b), hs.to_owned()));
for_each_sample!(enc_str_sized_heap, |_, hs, b, _| assert_eq!(
Enc::enc_str_sized_heap(b),
hs.to_owned()
));
for_each_sample!(enc_str_slice, |_, hs, b, _| assert_eq!(
Enc::enc_str_slice(b),
hs.to_owned()
));
}

View File

@ -13,6 +13,7 @@
#![feature(const_maybe_uninit_array_assume_init)]
#![feature(const_maybe_uninit_uninit_array)]
#![cfg_attr(feature = "alloc", feature(new_uninit))]
#![feature(stdsimd)]
#![feature(portable_simd)]
// ignores warning about `generic_const_exprs`
#![allow(incomplete_features)]
@ -20,10 +21,14 @@
#[cfg(feature = "alloc")]
extern crate alloc;
pub(crate) mod simd;
pub mod simd;
pub(crate) mod util;
pub(crate) mod test;
#[doc(hidden)]
#[cfg(any(test, feature = "test"))]
pub mod test;
pub mod dec;
pub mod enc;
pub mod prelude;

12
src/prelude.rs Normal file
View File

@ -0,0 +1,12 @@
pub(crate) use crate::simd;
pub(crate) use crate::util;
pub(crate) use simd::arch;
pub(crate) use simd::if_trace_simd;
pub(crate) use simd::SimdBitwise;
pub(crate) use simd::SimdLoad;
pub(crate) use simd::SimdSplat;
pub(crate) use simd::SimdTestAnd;
pub(crate) use util::array_op;
pub(crate) use util::unroll;

View File

@ -1,11 +1,16 @@
use core::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount};
use crate::util::cast;
use crate::util;
use util::cast;
const W_128: usize = 128 / 8;
const W_256: usize = 256 / 8;
const W_512: usize = 512 / 8;
/// The value which cause `vpshufb` to write 0 instead of indexing.
const SHUF_0: u8 = 0b1000_0000;
#[cfg(target_arch = "aarch64")]
pub use core::arch::aarch64 as arch;
#[cfg(target_arch = "arm")]
@ -19,10 +24,21 @@ pub use core::arch::x86 as arch;
#[cfg(target_arch = "x86_64")]
pub use core::arch::x86_64 as arch;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub use arch::{__m128i, __m256i, __m512i};
// use the maximum batch size that would be supported by AVX-512
//pub(crate) const SIMD_WIDTH: usize = 512;
pub const SIMD_WIDTH: usize = 256;
/// The batch size used for the "wide" decoded hex bytes (any bit in the upper half indicates an error).
pub const WIDE_BATCH_SIZE: usize = SIMD_WIDTH / 16;
/// The batch size used for the hex digits.
pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2;
pub const GATHER_BATCH_SIZE: usize = DIGIT_BATCH_SIZE / 4;
#[macro_export]
macro_rules! __if_trace_simd {
($( $tt:tt )*) => {
@ -51,7 +67,7 @@ where
macro_rules! specialized {
($(
$vis:vis fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$vis:vis fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $args:tt )*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
$width:pat_param $( if $cfg:meta )? => $impl:expr
),+
@ -60,7 +76,7 @@ macro_rules! specialized {
)+) => {$(
#[allow(dead_code)]
#[inline(always)]
$vis fn $name<const $LANES: usize$(, $( $generics )+)?>($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
$vis fn $name<const $LANES: usize$(, $( $generics )+)?>($( $args )*) -> $rt $( where $( $where )* )? {
// abusing const generics to specialize without the unsoundness of real specialization!
match $LANES {
$(
@ -70,9 +86,9 @@ macro_rules! specialized {
}
}
)+};
($trait:ident for $ty:ty;
($LANES:ident =>
$(
fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
fn $name:ident$(<[$( $generics:tt )+]>)?($( $args:tt )*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
$width:pat_param $( if $cfg:meta )? => $impl:expr
),+
@ -80,9 +96,9 @@ macro_rules! specialized {
}
)+
) => {
impl<const $LANES: usize> $trait for $ty {$(
$(
#[inline(always)]
fn $name$(<$( $generics )+>)?($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
fn $name$(<$( $generics )+>)?($( $args )*) -> $rt $( where $( $where )* )? {
// abusing const generics to specialize without the unsoundness of real specialization!
match $LANES {
$(
@ -91,85 +107,192 @@ macro_rules! specialized {
),+
}
}
)+
};
($trait:ident for $ty:ty;
$(
fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $args:tt )*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
$width:pat_param $( if $cfg:meta )? => $impl:expr
),+
$(,)?
}
)+
) => {
impl<const $LANES: usize> $trait for $ty {$(
specialized! { LANES =>
fn $name$(<[$( $generics:tt )+]>)?($( $args )*) -> $rt $(where [$( $where )*])? {
$(
$width $( if $cfg )? => $impl
),+
}
}
)+}
};
}
macro_rules! set1 {
($arch:ident, $inst:ident, $vec:ident, $reg:ident, $n:ident) => {{
let out: core::arch::$arch::$vec;
core::arch::asm!(concat!(stringify!($inst), " {}, {}"), lateout($reg) out, in(xmm_reg) cast::<_, core::arch::$arch::__m128i>($n), options(pure, nomem, preserves_flags, nostack));
macro_rules! set1_short {
($inst:ident, $vec:ident, $reg:ident, $n:ident: $n_ty:ty) => {{
// WOW. this is 12% faster than broadcast on xmm. Seems not so much on ymm (more expensive
// array init?).
const O_LANES: usize = std::mem::size_of::<$vec>() / std::mem::size_of::<$n_ty>();
util::cast::<_, $vec>(Simd::<$n_ty, O_LANES>::from_array([$n; O_LANES]))
//let out: $vec;
//core::arch::asm!(concat!(stringify!($inst), " {}, {}"), lateout($reg) out, in(xmm_reg) cast::<_, __m128i>($n), options(pure, nomem, preserves_flags, nostack));
//out
}};
}
macro_rules! set1_long {
($inst:ident, $vec:ident, $reg:ident, $n:ident: $n_ty:ty) => {{
//const O_LANES: usize = std::mem::size_of::<$vec>() / std::mem::size_of::<$n_ty>();
//util::cast::<_, $vec>(Simd::<$n_ty, O_LANES>::from_array([$n; O_LANES]))
let out: $vec;
core::arch::asm!(concat!(stringify!($inst), " {}, {}"), lateout($reg) out, in(xmm_reg) cast::<_, __m128i>($n), options(pure, nomem, preserves_flags, nostack));
out
}};
}
pub trait DoubleWidth {
type Output;
}
macro_rules! impl_double_width {
($($in:ty => $out:ty),+) => {
$(
impl DoubleWidth for $in {
type Output = $out;
}
)+
}
}
impl_double_width!(u8 => u16, u16 => u32, u32 => u64);
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl_double_width!(u64 => arch::__m128i, arch::__m128i => arch::__m256i, arch::__m256i => arch::__m512i);
pub trait SimdSplat<const LANES: usize> {
type Output;
fn splat(self) -> Self::Output;
fn splat_zero() -> Self::Output;
}
pub trait SimdLoad: Sized {
fn load_128(self) -> arch::__m128i;
#[inline(always)]
fn load_256(self) -> arch::__m256i {
unsafe { util::cast(self.load_128()) }
}
#[inline(always)]
fn load_512(self) -> arch::__m512i {
unsafe { util::cast(self.load_128()) }
}
}
macro_rules! impl_load {
(($in:ident, $($in_fmt:ident)?) $in_v:ident) => {{
unsafe {
let out: _;
core::arch::asm!(concat!("vmovq {}, {:", $(stringify!($in_fmt), )? "}"), lateout(xmm_reg) out, in($in) $in_v, options(pure, nomem, preserves_flags, nostack));
out
}};
}
}}
}
macro_rules! simd_load_fallback {
($ty:ty, $vec_ty:ty, $self:ident) => {{
let mut a = std::mem::MaybeUninit::uninit_array();
a[0] = std::mem::MaybeUninit::new($self);
for i in 1..(std::mem::size_of::<$vec_ty>() / std::mem::size_of::<$ty>()) {
a[i] = std::mem::MaybeUninit::new(0);
}
let a = unsafe { std::mem::MaybeUninit::array_assume_init(a) };
Simd::from_array(a).into()
}};
}
macro_rules! impl_ops {
($(
$ty:ty {
/// The appropriate register type.
reg: $reg:ident,
reg_fmt: $($reg_fmt:ident)?,
broadcast: $broadcast:ident,
set1: [$set1_128:ident, $set1_256:ident, $set1_512:ident]
$(,)?
}
)+) => {$(
impl SimdLoad for $ty {
#[inline(always)]
fn load_128(self) -> arch::__m128i {
#[cfg(target_feature = "avx")]
{ impl_load!(($reg, $($reg_fmt)?) self) }
#[cfg(not(target_feature = "avx"))]
simd_load_fallback!($ty, arch::__m128i, self)
}
}
impl<const LANES: usize> SimdSplat<LANES> for $ty where LaneCount<LANES>: SupportedLaneCount {
type Output = Simd<$ty, LANES>;
specialized! { LANES =>
fn splat(self) -> Self::Output {
W_128 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx2") => unsafe { cast(set1_short!($broadcast, __m128i, xmm_reg, self: $ty)) },
W_256 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx2") => unsafe { cast(set1_long!($broadcast, __m256i, ymm_reg, self: $ty)) },
W_512 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx512bw") => unsafe { cast(set1_long!($broadcast, __m512i, zmm_reg, self: $ty)) },
// these are *terrible*. They compile to a bunch of MOVs and SETs
W_128 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(arch::$set1_128(self as i8)) },
W_256 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(arch::$set1_256(self as i8)) },
// I can't actually test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
W_512 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx512f") => unsafe { cast(arch::$set1_512(self as i8)) },
_ => Simd::splat(self),
}
fn splat_zero() -> Self::Output {
W_128 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "sse2") => unsafe { cast(arch::_mm_setzero_si128()) },
W_256 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx") => unsafe { cast(arch::_mm256_setzero_si256()) },
W_512 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx512f") => unsafe { cast(arch::_mm512_setzero_si512()) },
_ => Self::splat(0),
}
}
}
)+};
}
impl_ops! {
u8 {
reg: reg_byte,
reg_fmt: ,
broadcast: vpbroadcastb,
set1: [_mm_set1_epi8, _mm256_set1_epi8, _mm512_set1_epi8]
}
specialized! {
// TODO: special case https://www.felixcloutier.com/x86/vpbroadcastb:vpbroadcastw:vpbroadcastd:vpbroadcastq
pub fn splat_n<LANES>(n: u8) -> Simd<u8, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastb, __m128i, xmm_reg, n)) },
W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastb, __m128i, xmm_reg, n)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastb, __m256i, ymm_reg, n)) },
W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastb, __m256i, ymm_reg, n)) },
// these are *terrible*. They compile to a bunch of MOVs and SETs
W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) },
W_128 if all(target_arch = "x86", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm256_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm256_set1_epi8(n as i8)) },
// I can't really test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi8(n as i8)) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi8(n as i8)) },
_ => Simd::splat(n),
u16 {
reg: reg,
reg_fmt: e,
broadcast: vpbroadcastw,
set1: [_mm_set1_epi16, _mm256_set1_epi16, _mm512_set1_epi16]
}
// TODO: special case https://www.felixcloutier.com/x86/vpbroadcastb:vpbroadcastw:vpbroadcastd:vpbroadcastq
pub fn splat_u16<LANES>(n: u16) -> Simd<u16, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastw, __m128i, xmm_reg, n)) },
W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastw, __m128i, xmm_reg, n)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastw, __m256i, ymm_reg, n)) },
W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastw, __m256i, ymm_reg, n)) },
// these are *terrible*. They compile to a bunch of MOVs and SETs
W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi16(n as i16)) },
W_128 if all(target_arch = "x86", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm_set1_epi16(n as i16)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm256_set1_epi16(n as i16)) },
W_256 if all(target_arch = "x86", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm256_set1_epi16(n as i16)) },
// I can't really test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi16(n as i16)) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi16(n as i16)) },
_ => Simd::splat(n),
u32 {
reg: reg,
reg_fmt: e,
broadcast: vpbroadcastd,
set1: [_mm_set1_epi32, _mm256_set1_epi32, _mm512_set1_epi32]
}
pub fn splat_0u8<LANES>() -> Simd<u8, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
// these are fine, they are supposed to XOR themselves to zero out.
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
_ => unsafe { crate::util::cast(splat_n(0)) },
}
pub fn splat_0u16<LANES>() -> Simd<u16, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
// these are fine, they are supposed to XOR themselves to zero out.
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
_ => unsafe { crate::util::cast(splat_n(0)) },
u64 {
reg: reg,
reg_fmt: r,
broadcast: vpbroadcastq,
set1: [_mm_set1_epi64, _mm256_set1_epi64, _mm512_set1_epi64]
}
}
@ -235,15 +358,6 @@ macro_rules! __swizzle {
pub use __swizzle as swizzle;
pub use __swizzle_indices as swizzle_indices;
#[inline(always)]
pub fn load_u64_m128(v: u64) -> arch::__m128i {
unsafe {
let out: _;
core::arch::asm!("vmovq {}, {}", lateout(xmm_reg) out, in(reg) v, options(pure, nomem, preserves_flags, nostack));
out
}
}
/// Merges the low halves of `a` and `b` into a single register like `ab`.
#[inline(always)]
pub fn merge_lo_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
@ -286,7 +400,7 @@ macro_rules! extract_lohi_bytes {
#[inline(always)]
pub fn extract_lo_bytes(v: arch::__m256i) -> arch::__m128i {
extract_lohi_bytes!(([0xffu16; 8], vpand, vpackuswb), v)
extract_lohi_bytes!(([0x00ffu16; 8], vpand, vpackuswb), v)
}
#[inline(always)]
@ -314,13 +428,18 @@ pub trait SimdBitwise {
fn and(self, rhs: Self) -> Self;
}
pub trait SimdWiden {
/// Widens the lower bytes by spacing and zero-extending them to N bytes.
fn widen<const N: usize, const O: usize, const BE: bool>(self) -> Self;
}
#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m128i {
#[inline(always)]
fn test_and_non_zero(self, mask: Self) -> bool {
unsafe {
let out: u8;
core::arch::asm!("vptest {a}, {b}", "jz 2f", "mov {out}, 1", "jnz 3f", "2:", "mov {out}, 0", "3:", a = in(xmm_reg) self, b = in(xmm_reg) mask, out = out(reg_byte) out, options(pure, nomem, nostack));
core::arch::asm!("vptest {a}, {b}", "jz 2f", "mov {out}, 1", "jnz 3f", "2:", "mov {out}, 0", "3:", a = in(xmm_reg) self, b = in(xmm_reg) mask, out = lateout(reg_byte) out, options(pure, nomem, nostack));
core::mem::transmute(out)
}
}
@ -332,7 +451,7 @@ impl SimdTestAnd for arch::__m256i {
fn test_and_non_zero(self, mask: Self) -> bool {
unsafe {
let out: u8;
core::arch::asm!("vptest {a}, {b}", "jz 2f", "mov {out}, 1", "jnz 3f", "2:", "mov {out}, 0", "3:", a = in(ymm_reg) self, b = in(ymm_reg) mask, out = out(reg_byte) out, options(pure, nomem, nostack));
core::arch::asm!("vptest {a}, {b}", "jz 2f", "mov {out}, 1", "jnz 3f", "2:", "mov {out}, 0", "3:", a = in(ymm_reg) self, b = in(ymm_reg) mask, out = lateout(reg_byte) out, options(pure, nomem, nostack));
core::mem::transmute(out)
}
}
@ -349,7 +468,7 @@ impl SimdBitwise for arch::__m128i {
arch::_mm_or_si128(self, rhs)
} else {
let out: _;
core::arch::asm!("vpor {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = out(xmm_reg) out, options(pure, nomem, preserves_flags, nostack));
core::arch::asm!("vpor {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = lateout(xmm_reg) out, options(pure, nomem, preserves_flags, nostack));
out
}
}
@ -362,7 +481,7 @@ impl SimdBitwise for arch::__m128i {
arch::_mm_and_si128(self, rhs)
} else {
let out: _;
core::arch::asm!("vpand {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = out(xmm_reg) out, options(pure, nomem, preserves_flags, nostack));
core::arch::asm!("vpand {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = lateout(xmm_reg) out, options(pure, nomem, preserves_flags, nostack));
out
}
}
@ -378,7 +497,7 @@ impl SimdBitwise for arch::__m256i {
arch::_mm256_or_si256(self, rhs)
} else {
let out: _;
core::arch::asm!("vpor {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = out(ymm_reg) out, options(pure, nomem, preserves_flags, nostack));
core::arch::asm!("vpor {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = lateout(ymm_reg) out, options(pure, nomem, preserves_flags, nostack));
out
}
}
@ -391,9 +510,192 @@ impl SimdBitwise for arch::__m256i {
arch::_mm256_and_si256(self, rhs)
} else {
let out: _;
core::arch::asm!("vpand {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = out(ymm_reg) out, options(pure, nomem, preserves_flags, nostack));
core::arch::asm!("vpand {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = lateout(ymm_reg) out, options(pure, nomem, preserves_flags, nostack));
out
}
}
}
}
macro_rules! widen_128_impl {
($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:literal, $offset:literal) => {unsafe {
const LEN: usize = std::mem::size_of::<$Self>();
const MASK_MOD_EQ: usize = if $be { 0 } else { $shift };
// TODO: evaluate whether there is a more efficient approach for the ignored bytes.
const INDICES: [u8; LEN] = util::array_op!(gen[LEN] |i| {
if (i + 1) % ($shift + 1) == MASK_MOD_EQ {
((i as u8) >> $shift) + $offset
} else {
SHUF_0
}
});
const INDICES_SIMD: $Self = unsafe { util::cast(INDICES) };
//const MASK: [u8; LEN] = util::array_op!(gen[LEN] |i| {
// if (i + 1) % ($shift + 1) == MASK_MOD_EQ { 0xff } else { 0x00 }
//});
//const MASK: [u8; LEN] = util::array_op!(gen[LEN] |i| if ((i + 1) % ($shift + 1) == $offset) { 0xff } else { 0x00 });
let out: _;
core::arch::asm!("vpshufb {out}, {in}, {indices}", in = in($reg) $self, indices = in($reg) INDICES_SIMD, out = lateout($reg) out, options(pure, nomem, preserves_flags, nostack));
if_trace_simd! {
println!("Offset: {}", $offset);
println!("Indices: {:?}", INDICES);
//println!("MASK: {MASK:?}");
}
out
//SimdBitwise::and(out, util::cast(MASK))
}};
($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:ident, $offset_in:ident [$( $offset:literal ),+]) => {
match $offset_in {
$( $offset => if $be { widen_128_impl!($self, $Self, $reg, $shift, true, $offset) } else { widen_128_impl!($self, $Self, $reg, $shift, false, $offset) }, )+
_ => panic!("Unsupported widen O value"),
}
};
($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:ident, $offset:ident) => {
widen_128_impl!($self, $Self, $reg, $shift, $be, $offset [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
};
}
macro_rules! widen_256_impl {
($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:literal, $offset:literal) => {unsafe {
const LEN: usize = std::mem::size_of::<$Self>();
const MASK_MOD_EQ: usize = if $be { 0 } else { $shift };
// TODO: evaluate whether there is a more efficient approach for the ignored bytes.
const INDICES_LO: [u8; LEN] = util::array_op!(gen[LEN] |i| {
let i = i % LEN;
if (i + 1) % ($shift + 1) == MASK_MOD_EQ {
((i as u8) >> $shift) + $offset
} else {
SHUF_0
}
});
const INDICES_HI: [u8; LEN] = util::array_op!(gen[LEN] |i| {
let i = (i % (LEN / 2)) + LEN / 2;
if (i + 1) % ($shift + 1) == MASK_MOD_EQ {
((i as u8) >> $shift) + $offset
} else {
SHUF_0
}
});
const INDICES_LO_SIMD: $Self = unsafe { util::cast(INDICES_LO) };
const INDICES_HI_SIMD: $Self = unsafe { util::cast(INDICES_HI) };
const HI_MASK: [u8; LEN] = util::array_op!(gen[LEN] |i| if i < LEN / 2 { 0x00 } else { 0xff });
const HI_MASK_SIMD: $Self = unsafe { util::cast(HI_MASK) };
//const MASK: [u8; LEN] = util::array_op!(gen[LEN] |i| {
// if (i + 1) % ($shift + 1) == MASK_MOD_EQ { 0xff } else { 0x00 }
//});
//const MASK: [u8; LEN] = util::array_op!(gen[LEN] |i| if ((i + 1) % ($shift + 1) == $offset) { 0xff } else { 0x00 });
let lo: $Self;
core::arch::asm!("vpshufb {out:x}, {in:x}, {indices:x}", in = in($reg) $self, indices = in($reg) INDICES_LO_SIMD, out = lateout($reg) lo, options(pure, nomem, preserves_flags, nostack));
let hi: $Self;
core::arch::asm!("vpshufb {out:x}, {in:x}, {indices:x}", in = in($reg) $self, indices = in($reg) INDICES_HI_SIMD, out = lateout($reg) hi, options(pure, nomem, preserves_flags, nostack));
let out = lo.or(hi.and(HI_MASK_SIMD));
if_trace_simd! {
println!("Offset: {}", $offset);
println!("Indices: {:?},{:?}", INDICES_LO, INDICES_HI);
//println!("MASK: {MASK:?}");
}
out
//SimdBitwise::and(out, util::cast(MASK))
}};
($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:ident, $offset_in:ident [$( $offset:literal ),+]) => {
match $offset_in {
$( $offset => if $be { widen_256_impl!($self, $Self, $reg, $shift, true, $offset) } else { widen_256_impl!($self, $Self, $reg, $shift, false, $offset) }, )+
_ => panic!("Unsupported widen O value"),
}
};
($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:ident, $offset:ident) => {
widen_256_impl!($self, $Self, $reg, $shift, $be, $offset [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])
};
}
#[doc(hidden)]
#[macro_export]
macro_rules! __simd__shr_64 {
($n:literal, ($in_reg:ident) $in:expr) => {{
let out: _;
std::arch::asm!(concat!("vpsrlq {dst}, {src}, ", $n), src = in($in_reg) $in, dst = lateout($in_reg) out);
out
}};
}
pub use __simd__shr_64 as shr_64;
/*impl SimdWiden for arch::__m128i {
#[inline(always)]
fn widen<const N: usize, const O: usize, const BE: bool>(self) -> Self {
match N {
1 => self,
2 => widen_128_impl!(self, arch::__m128i, xmm_reg, 1, BE, O),
4 => widen_128_impl!(self, arch::__m128i, xmm_reg, 2, BE, O),
8 => widen_128_impl!(self, arch::__m128i, xmm_reg, 3, BE, O),
_ => panic!("Unsupported widen N value"),
}
}
}
impl SimdWiden for arch::__m256i {
#[inline(always)]
fn widen<const N: usize, const O: usize, const BE: bool>(self) -> Self {
match N {
1 => self,
2 => widen_256_impl!(self, arch::__m256i, ymm_reg, 1, BE, O),
4 => widen_256_impl!(self, arch::__m256i, ymm_reg, 2, BE, O),
8 => widen_256_impl!(self, arch::__m256i, ymm_reg, 3, BE, O),
16 => widen_256_impl!(self, arch::__m256i, ymm_reg, 4, BE, O),
_ => panic!("Unsupported widen N value"),
}
}
}*/
#[cfg(test)]
mod test {
use super::*;
/*#[test]
fn test_widen() {
const SAMPLE_IN: [u8; 32] = [
1, 2, 3, 4, 5, 6, 7, 8,
1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16,
9, 10, 11, 12, 13, 14, 15, 16
];
const SAMPLE_OUT_LO_BE: [u8; 32] = [
0, 1, 0, 2, 0, 3, 0, 4,
0, 5, 0, 6, 0, 7, 0, 8,
0, 1, 0, 2, 0, 3, 0, 4,
0, 5, 0, 6, 0, 7, 0, 8
];
const SAMPLE_OUT_HI_BE: [u8; 32] = [
0, 9, 0, 10, 0, 11, 0, 12,
0, 13, 0, 14, 0, 15, 0, 16,
0, 9, 0, 10, 0, 11, 0, 12,
0, 13, 0, 14, 0, 15, 0, 16
];
let sample: arch::__m256i = Simd::from_array(SAMPLE_IN).into();
let widened = sample.widen::<2, 0, true>();
//assert_eq!(Simd::from(widened), Simd::from_array(SAMPLE_OUT_LO_BE));
let widened = sample.widen::<2, 16, true>();
assert_eq!(Simd::from(widened), Simd::from_array(SAMPLE_OUT_HI_BE));
const SAMPLE_OUT_LO_LE: [u8; 32] = [
1, 0, 2, 0, 3, 0, 4, 0,
5, 0, 6, 0, 7, 0, 8, 0,
1, 0, 2, 0, 3, 0, 4, 0,
5, 0, 6, 0, 7, 0, 8, 0
];
const SAMPLE_OUT_HI_LE: [u8; 32] = [
9, 0, 10, 0, 11, 0, 12, 0,
13, 0, 14, 0, 15, 0, 16, 0,
9, 0, 10, 0, 11, 0, 12, 0,
13, 0, 14, 0, 15, 0, 16, 0
];
let sample: arch::__m256i = Simd::from_array(SAMPLE_IN).into();
let widened = sample.widen::<2, 0, false>();
//assert_eq!(Simd::from(widened), Simd::from_array(SAMPLE_OUT_LO_LE));
let widened = sample.widen::<2, 16, false>();
assert_eq!(Simd::from(widened), Simd::from_array(SAMPLE_OUT_HI_LE));
}*/
}

View File

@ -36,3 +36,22 @@ pub const INVALID_SAMPLES: &[&str] = &[
"446F6C6F72756D2064697374696E6374696G20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E",
"446F6C6F72756D2064697374696E637469GF20757420656172756D2071756964656D2064697374696E6374696F206E65636573736974617469627573207175616D2E20536974207072616573656E7469756D20666163657265207065727370696369617469732069757265206175742073756E742065742065742E20416469706973636920656E696D20726572756D20696C6C756D206574206F666669636961206E697369207265637573616E6461652E20566974616520646F6C6F7269627573207574207175696120656120756E646520636F6E73657175756E747572207175616520696C6C756D2E204964206569757320686172756D206573742E20496E76656E746F726520697073756D20757420736974207574207665726F20636F6E73656374657475722E",
];
#[doc(hidden)]
#[macro_export]
macro_rules! __test__name {
($group:literal, $f:literal) => {
concat!("[", $group, "] - ", $f)
};
($group:expr, $f:literal) => {
std::boxed::Box::leak(format!(name!("{}", $f), $group).into_boxed_str())
};
($group:expr, $f:expr) => {
std::boxed::Box::leak(format!(name!("{}", "{}"), $group, $f).into_boxed_str())
};
($group:literal, $f:expr) => {
std::boxed::Box::leak(format!(name!($group, "{}"), $f).into_boxed_str())
};
}
pub use __test__name as name;

View File

@ -2,7 +2,7 @@
#[macro_export]
macro_rules! __array_op {
(gen[$len:expr] |$i:pat_param| $val:expr) => {{
let mut out = std::mem::MaybeUninit::uninit_array();
let mut out = std::mem::MaybeUninit::uninit_array::<$len>();
let mut i = 0;
while i < $len {
out[i] = std::mem::MaybeUninit::new(match i {
@ -10,7 +10,10 @@ macro_rules! __array_op {
});
i += 1;
}
unsafe { std::mem::MaybeUninit::array_assume_init(out) }
#[allow(unused_unsafe)]
unsafe {
std::mem::MaybeUninit::array_assume_init(out)
}
}};
(map[$len:expr, $src:expr] |$i:pat_param, $s:pat_param| $val:expr) => {{
$crate::util::array_op!(
@ -27,6 +30,11 @@ macro_rules! __array_op {
pub use __array_op as array_op;
#[inline]
pub const fn cast_u8_u32<const N: usize>(arr: [u8; N]) -> [u32; N] {
array_op!(map[N, arr] |_, v| v as u32)
}
#[doc(hidden)]
#[macro_export]
macro_rules! __defer_impl {
@ -111,6 +119,71 @@ pub const fn align_up_to<const N: usize>(n: usize) -> usize {
return (n + (N - 1)) >> shift << shift;
}
#[inline(always)]
pub unsafe fn alloc_aligned_box<T, const ALIGN: usize>() -> Box<T> {
let ptr = alloc_aligned(
std::mem::size_of::<T>(),
std::cmp::max(std::mem::align_of::<T>(), ALIGN),
);
Box::from_raw(ptr as *mut T)
}
#[inline(always)]
pub unsafe fn alloc_aligned_box_slice<T, const ALIGN: usize>(len: usize) -> Box<[T]> {
let size = std::cmp::max(std::mem::size_of::<T>(), std::mem::align_of::<T>());
let ptr = alloc_aligned(size * len, std::cmp::max(std::mem::align_of::<T>(), ALIGN));
Box::from_raw(std::ptr::slice_from_raw_parts_mut(ptr as *mut T, len))
}
#[inline(always)]
pub unsafe fn alloc_aligned(len: usize, align: usize) -> *mut u8 {
let layout = std::alloc::Layout::from_size_align_unchecked(len, align);
std::alloc::alloc(layout)
}
#[macro_export]
macro_rules! __util__unroll {
(let [$( $x:ident ),+] => |$y:pat_param| $expr:expr) => {
$crate::util::unroll!(let [$( $x: ($x) ),+] => |$y| $expr);
};
(let [$( $id:ident: ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => {
$crate::util::unroll!(@ [$] [$( $id: ($( $x ),+) ),+] => |$( $y ),+| $($expr)+)
};
(@ [$dollar:tt] [$( $id:ident: ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => {
macro_rules! __unrolled {
($dollar id:ident, $dollar ( $dollar z:tt ),+) => {
#[allow(unused_parens)]
let ($( $y ),+) = ($dollar ( $dollar z ),+);
let $dollar id = $($expr)+;
};
}
$( __unrolled!($id, $( $x ),+); )+
//$(
// let ($( $y ),+) = ($( $x ),+);
// $expr;
//)+
};
([$( ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => {
$crate::util::unroll!(@ [$] [$( ($( $x ),+) ),+] => |$( $y ),+| $($expr)+)
};
(@ [$dollar:tt] [$( ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => {
macro_rules! __unrolled {
($dollar ( $dollar z:tt ),+) => {
#[allow(unused_parens)]
let ($( $y ),+) = ($dollar ( $dollar z ),+);
$($expr)+;
};
}
$( __unrolled!($( $x ),+); )+
//$(
// let ($( $y ),+) = ($( $x ),+);
// $expr;
//)+
};
}
pub use __util__unroll as unroll;
#[cfg(test)]
mod test {
use super::*;
@ -123,4 +196,33 @@ mod test {
assert_eq!(align_down_to::<16>(15), 0);
assert_eq!(align_down_to::<16>(17), 16);
}
#[test]
pub fn test_array_op_gen() {
assert_eq!(array_op!(gen[4] | i | i), [0, 1, 2, 3]);
assert_eq!(
array_op!(gen[16] | i | i),
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
);
assert_eq!(array_op!(gen[4] | i | i as u8), [0u8, 1, 2, 3]);
assert_eq!(
array_op!(gen[16] | i | i as u8),
[0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
);
const A: [u8; 4] = array_op!(gen[4] | i | i as u8);
const B: [u8; 16] = array_op!(gen[16] | i | i as u8);
assert_eq!(A, [0u8, 1, 2, 3]);
assert_eq!(B, [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
const LEN: usize = std::mem::size_of::<std::arch::x86_64::__m256i>();
const INDICES: [u8; LEN] = array_op!(gen[LEN] | i | (i as u8) >> 1);
assert_eq!(
INDICES,
[
0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12,
13, 13, 14, 14, 15, 15
]
);
}
}