Vectorized implementation is now faster than scalar

- For all tested sizes
- Excluding micro-benchmarks (where it still comes very close)
Michael Pfaff 2022-11-01 15:01:30 -04:00
parent 84d8ef7748
commit 6eb5c5e46a
Signed by: michael
GPG Key ID: CF402C4A012AA9D4
3 changed files with 317 additions and 251 deletions


@ -27,6 +27,9 @@ use core::simd::*;
use alloc::{boxed::Box, vec::Vec};
use simd::SimdTestAnd as _;
use simd::SimdBitwise as _;
use util::array_op;
// use the maximum batch size that would be supported by AVX-512
//pub const SIMD_WIDTH: usize = 512;
@ -43,74 +46,24 @@ const GATHER_BATCH_SIZE: usize = DIGIT_BATCH_SIZE / 4;
macro_rules! if_trace_simd {
($( $tt:tt )*) => {
// disabled
//$( $tt )*
//{ $( $tt )* }
};
}
const VALIDATE: bool = true;
#[inline]
const fn alternating_mask<const N: usize>(first_bias: bool) -> [bool; N] {
let mut mask = [false; N];
let mut i = 0;
if first_bias {
while i < N / 2 {
mask[i * 2] = true;
i += 1;
}
} else {
while i < N / 2 {
mask[i * 2 + 1] = true;
i += 1;
}
}
mask
}
#[inline]
const fn msb_lsb_indices<const N: usize>() -> [usize; N] {
if N % 2 != 0 {
panic!("Illegal N");
}
let mut indices = [0; N];
let mut i = 0;
while i < N / 2 {
indices[i] = i * 2;
indices[N / 2 + i] = i * 2 + 1;
i += 1;
}
indices
}
#[inline]
const fn alternating_indices<const N: usize>(first_bias: bool) -> [usize; N] {
let mut indices = [0; N];
let mut i = 0;
if first_bias {
while i < N {
indices[i] = i * 2;
i += 1;
}
array_op!(gen[N] |i| i * 2)
} else {
while i < N {
indices[i] = i * 2 + 1;
i += 1;
}
array_op!(gen[N] |i| i * 2 + 1)
}
indices
}
#[inline]
const fn cast_usize_u8<const N: usize>(arr: [usize; N]) -> [u8; N] {
let mut arr1 = [0u8; N];
let mut i = 0;
while i < arr.len() {
arr1[i] = arr[i] as u8;
i += 1;
}
arr1
const fn cast_u8_u32<const N: usize>(arr: [u8; N]) -> [u32; N] {
array_op!(map[N, arr] |_, v| v as u32)
}
const MSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(true);
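For orientation, the two index tables pick out alternating lanes; illustratively (first entries only, not part of the diff):
MSB_INDICES == [0, 2, 4, 6, ...]   // alternating_indices(true): even lanes
LSB_INDICES == [1, 3, 5, 7, ...]   // alternating_indices(false): odd lanes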
@ -121,9 +74,7 @@ pub const INVALID_BIT: u8 = 0b1000_0000;
pub const WIDE_INVALID_BIT: u16 = 0b1000_1000_0000_0000;
const ASCII_DIGITS: [u8; 256] = {
let mut digits = [0u8; 256];
let mut i = u8::MIN;
while i < u8::MAX {
array_op!(gen[256] |i| {
const DIGIT_MIN: u8 = '0' as u8;
const DIGIT_MAX: u8 = '9' as u8;
const LOWER_MIN: u8 = 'a' as u8;
@ -131,27 +82,17 @@ const ASCII_DIGITS: [u8; 256] = {
const UPPER_MIN: u8 = 'A' as u8;
const UPPER_MAX: u8 = 'F' as u8;
digits[i as usize] = match i {
let i = i as u8;
match i {
DIGIT_MIN..=DIGIT_MAX => i - DIGIT_MIN,
LOWER_MIN..=LOWER_MAX => 10 + i - LOWER_MIN,
UPPER_MIN..=UPPER_MAX => 10 + i - UPPER_MIN,
_ => INVALID_BIT,
};
i += 1;
}
digits
}
})
};
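For reference, a few sample entries of the resulting lookup table (the values follow directly from the match above):
ASCII_DIGITS[b'0' as usize] == 0x0
ASCII_DIGITS[b'9' as usize] == 0x9
ASCII_DIGITS[b'a' as usize] == 0xa
ASCII_DIGITS[b'F' as usize] == 0xf
ASCII_DIGITS[b'g' as usize] == INVALID_BIT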
const __ASCII_DIGITS_SIMD: [u32; 256] = {
let mut digits = [0u32; 256];
let mut i = u8::MIN;
while i < u8::MAX {
digits[i as usize] = ASCII_DIGITS[i as usize] as u32;
i += 1;
}
digits
};
const __ASCII_DIGITS_SIMD: [u32; 256] = cast_u8_u32(ASCII_DIGITS);
const ASCII_DIGITS_SIMD: *const i32 = &__ASCII_DIGITS_SIMD as *const u32 as *const i32;
@ -194,7 +135,7 @@ where
&ASCII_DIGITS,
Mask::splat(true),
ascii.cast(),
simd::splat_0::<LANES>(),
simd::splat_0u8::<LANES>(),
)
}
}
@ -255,7 +196,7 @@ impl HexByteSimdDecoder for HexByteDecoderA {
fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>> {
let hex_digits = hex_digit_simd::<DIGIT_BATCH_SIZE>(Simd::from_array(hi_los));
if ((hex_digits & simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT))
.simd_ne(simd::splat_0::<DIGIT_BATCH_SIZE>()))
.simd_ne(simd::splat_0u8::<DIGIT_BATCH_SIZE>()))
.any()
{
return None;
@ -347,46 +288,131 @@ macro_rules! hex_digits_simd_inline {
let c = simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, c, 4);
let d = simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, d, 4);
let a = Simd::<u32, GATHER_BATCH_SIZE>::from(a).cast::<u8>();
let b = Simd::<u32, GATHER_BATCH_SIZE>::from(b).cast::<u8>();
let c = Simd::<u32, GATHER_BATCH_SIZE>::from(c).cast::<u8>();
let d = Simd::<u32, GATHER_BATCH_SIZE>::from(d).cast::<u8>();
if_trace_simd! {
println!("{a:x?}, {b:x?}, {c:x?}, {d:x?}");
}
// load the 64-bit integers into registers
let a = simd::load_u64_m128(util::cast(a));
let b = simd::load_u64_m128(util::cast(b));
let c = simd::load_u64_m128(util::cast(c));
let d = simd::load_u64_m128(util::cast(d));
{
let a = Simd::<u8, 16>::from(a);
let b = Simd::<u8, 16>::from(b);
let c = Simd::<u8, 16>::from(c);
let d = Simd::<u8, 16>::from(d);
if_trace_simd! {
println!("a,b,c,d: {a:x?}, {b:x?}, {c:x?}, {d:x?}");
}
}
if_trace_simd! {
let a = Simd::<u8, 16>::from(a);
let b = Simd::<u8, 16>::from(b);
let c = Simd::<u8, 16>::from(c);
let d = Simd::<u8, 16>::from(d);
println!("a,b,c,d: {a:x?}, {b:x?}, {c:x?}, {d:x?}");
}
// copy the second 64-bit integer into the upper half of xmm0 (lower half is the first 64-bit integer)
let ab = simd::merge_low_hi_m128(a, b);
// copy the fourth 64-bit integer into the upper half of xmm2 (lower half is the third 64-bit integer)
let cd = simd::merge_low_hi_m128(c, d);
// copy the second 64-bit integer into the upper half of xmm0 (lower half is the first 64-bit integer)
let ab = simd::merge_lo_hi_m128(a, b);
// copy the fourth 64-bit integer into the upper half of xmm2 (lower half is the third 64-bit integer)
let cd = simd::merge_lo_hi_m128(c, d);
{
let ab = Simd::<u8, 16>::from(ab);
let cd = Simd::<u8, 16>::from(cd);
if_trace_simd! {
println!("ab,cd: {ab:x?}, {cd:x?}");
}
}
if_trace_simd! {
let ab = Simd::<u8, 16>::from(ab);
let cd = Simd::<u8, 16>::from(cd);
println!("ab,cd: {ab:x?}, {cd:x?}");
}
// merge the xmm0 and xmm1 (ymm1) registers into ymm0
simd::merge_m128_m256(ab, cd)
// merge the xmm0 and xmm1 (ymm1) registers into ymm0
let abcd = simd::merge_m128_m256(ab, cd);
if_trace_simd! {
let abcd: Simd<u8, DIGIT_BATCH_SIZE> = abcd.into();
println!("abcd: {abcd:x?}");
}
abcd
}};
}
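Roughly, the data movement in this macro is (a sketch; GATHER_BATCH_SIZE = DIGIT_BATCH_SIZE / 4 as defined above):
a, b, c, d : four vpgatherdd lookups of GATHER_BATCH_SIZE digits each, narrowed u32 -> u8
ab   = [ a | b ]      // merge_lo_hi_m128: two low halves packed into one xmm
cd   = [ c | d ]
abcd = [ ab | cd ]    // merge_m128_m256: one ymm holding DIGIT_BATCH_SIZE hex digit values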
macro_rules! merge_hex_digits_into_bytes_inline {
($hex_digits:ident) => {{
//let hex_digits = Simd::<u8, DIGIT_BATCH_SIZE>::from($hex_digits);
//let msb = simd_swizzle!(hex_digits, MSB_INDICES);
//let lsb = simd_swizzle!(hex_digits, LSB_INDICES);
let msb = simd::extract_lo_bytes($hex_digits);
let lsb = simd::extract_hi_bytes($hex_digits);
let msb1: simd::arch::__m128i;
unsafe { std::arch::asm!("vpsllw {dst}, {src}, 4", src = in(xmm_reg) msb, dst = lateout(xmm_reg) msb1) };
if_trace_simd! {
let msb1: Simd<u8, WIDE_BATCH_SIZE> = msb1.into();
println!("msb1: {msb1:x?}");
}
//let msb2: simd::arch::__m128i;
//unsafe { std::arch::asm!("vpand {dst}, {src}, 0x0000_0000_0000_0000_0000_0000_0000_f0f0", src = in(xmm_reg) msb1, dst = lateout(xmm_reg) msb2) };
let msb2 = msb1.and(Simd::from_array([0xf0f0u16; WIDE_BATCH_SIZE / 2]).into());
let b = msb2.or(lsb);
if_trace_simd! {
let msb: Simd<u8, WIDE_BATCH_SIZE> = msb.into();
let msb1: Simd<u8, WIDE_BATCH_SIZE> = msb1.into();
let msb2: Simd<u8, WIDE_BATCH_SIZE> = msb2.into();
let lsb: Simd<u8, WIDE_BATCH_SIZE> = lsb.into();
let b: Simd<u8, WIDE_BATCH_SIZE> = b.into();
println!("| Packed | Msb | <<4 | & | Lsb | Bytes | |");
Simd::<u8, DIGIT_BATCH_SIZE>::from($hex_digits)
.to_array()
.chunks(2)
.zip(msb.to_array())
.zip(msb1.to_array())
.zip(msb2.to_array())
.zip(lsb.to_array())
.zip(b.to_array())
.for_each(|(((((chunk, msb), msb1), msb2), lsb), b)| {
println!(
"| {chunk:02x?} | {msb:x?} | {msb1:x?} | {msb2:x?} | {lsb:x?} | {b:02x?} | {ok} |",
chunk = (chunk[0] as u16) << 4 | (chunk[1] as u16),
ok = if chunk[0] == msb && chunk[1] == lsb {
'✓'
} else {
'✗'
}
);
});
}
b
/*let msb = simd::extract_lo_bytes($hex_digits);
let lsb = simd::extract_hi_bytes($hex_digits);
let msb: Simd<u8, WIDE_BATCH_SIZE> = msb.into();
let lsb: Simd<u8, WIDE_BATCH_SIZE> = lsb.into();
if_trace_simd! {
println!("msb: {msb:x?}");
println!("lsb: {lsb:x?}");
println!("| Packed | Msb | Lsb | |");
Simd::<u8, DIGIT_BATCH_SIZE>::from(hex_digits)
.to_array()
.chunks(2)
.zip(msb.to_array())
.zip(lsb.to_array())
.for_each(|((chunk, msb), lsb)| {
println!(
"| {chunk:02x?} | {msb:x?} | {lsb:x?} | {ok} |",
chunk = (chunk[0] as u16) << 4 | (chunk[1] as u16),
ok = if chunk[0] == msb && chunk[1] == lsb {
'✓'
} else {
'✗'
}
);
});
}
//simd::merge_lo_hi_m128(msb, lsb)
simd::arch::__m128i::from((msb << simd::splat_n::<WIDE_BATCH_SIZE>(4)) | lsb)*/
}};
}
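Per output byte this is the familiar scalar packing, e.g. hi = 0xa, lo = 0x5 gives (hi << 4) | lo == 0xa5. Since vpsllw shifts whole 16-bit lanes, the 0xf0f0 mask keeps only the high nibble of each byte, discarding anything shifted across a byte boundary (a sketch, per u16 lane):
msb1 = msb << 4        // each digit value moves into the high nibble of its byte
msb2 = msb1 & 0xf0f0   // drop bits shifted across the byte boundary
b    = msb2 | lsb      // OR in the low digits -> packed output bytes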
@ -406,16 +432,8 @@ impl HexByteSimdDecoder for HexByteDecoderB {
if hex_digits.test_and_non_zero(simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT).into()) {
return None;
}
let hex_digits = Simd::<u8, DIGIT_BATCH_SIZE>::from(hex_digits);
let msb = simd_swizzle!(hex_digits, MSB_INDICES);
let lsb = simd_swizzle!(hex_digits, LSB_INDICES);
let mut v = Simd::from_array([0u8; WIDE_BATCH_SIZE]);
for (i, v) in v.as_mut_array().iter_mut().enumerate() {
let hi = unsafe { *msb.as_array().get_unchecked(i) };
let lo = unsafe { *lsb.as_array().get_unchecked(i) };
*v = (hi << 4) | lo;
}
Some(v)
Some(merge_hex_digits_into_bytes_inline!(hex_digits).into())
}
}
@ -424,94 +442,16 @@ pub type HBD = HexByteDecoderB;
pub mod conv {
use core::simd::{LaneCount, Simd, SupportedLaneCount};
/*trait Size {
const N: usize;
}
macro_rules! size_impl {
($ident:ident($size:expr)) => {
struct $ident;
impl Size for $ident {
const N: usize = $size;
}
};
($ident:ident<$size:ty>) => {
size_impl!($ident(core::mem::size_of::<$size>()));
};
}
struct SizeMul<const N: usize, T>(core::marker::PhantomData<T>);
impl<const N: usize, T: Size> Size for SizeMul<N, T> {
const N: usize = T::N * N;
}
size_impl!(SizeU8<u8>);
size_impl!(SizeU16<u16>);
size_impl!(SizeU32<u32>);
size_impl!(SizeU64<u64>);
trait SizeOf {
type Size: Size;
//const SIZE: usize;
}
//impl<T> SizeOf for T {
// const SIZE: usize = core::mem::size_of::<T>();
//}
macro_rules! size_of_impl {
($type:ty = $size:ident) => {
impl SizeOf for $type {
type Size = $size;
}
};
}
size_of_impl!(u8 = SizeU8);
size_of_impl!([u8; 2] = SizeU16);
size_of_impl!(u16 = SizeU16);
size_of_impl!(u32 = SizeU32);
size_of_impl!(u64 = SizeU64);*/
#[allow(non_camel_case_types, non_snake_case)]
union u8_u16<const N_u16: usize>
where
[u8; N_u16 * 2]:,
{
u8: [u8; N_u16 * 2],
u16: [u16; N_u16],
}
#[allow(non_camel_case_types, non_snake_case)]
union u8x2_u8<const N_u8x2: usize>
where
[u8; N_u8x2 * 2]:,
{
u8x2: [[u8; 2]; N_u8x2],
u8: [u8; N_u8x2 * 2],
}
#[allow(non_camel_case_types, non_snake_case)]
union SimdU8_SimdU16<const N_U16: usize>
where
LaneCount<{ N_U16 * 2 }>: SupportedLaneCount,
LaneCount<N_U16>: SupportedLaneCount,
{
SimdU8: Simd<u8, { N_U16 * 2 }>,
SimdU16: Simd<u16, N_U16>,
}
use crate::util;
#[inline(always)]
pub const fn u8_to_u16<const N_OUT: usize>(a: [u8; N_OUT * 2]) -> [u16; N_OUT] {
unsafe { u8_u16 { u8: a }.u16 }
unsafe { util::cast(a) }
}
#[inline(always)]
pub const fn u8x2_to_u8<const N_IN: usize>(a: [[u8; 2]; N_IN]) -> [u8; N_IN * 2] {
unsafe { u8x2_u8 { u8x2: a }.u8 }
unsafe { util::cast(a) }
}
#[inline(always)]
@ -522,7 +462,7 @@ pub mod conv {
LaneCount<{ N_OUT * 2 }>: SupportedLaneCount,
LaneCount<N_OUT>: SupportedLaneCount,
{
unsafe { SimdU8_SimdU16 { SimdU8: a }.SimdU16 }
unsafe { util::cast(a) }
}
}
@ -598,7 +538,7 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
if VECTORED {
use simd::arch;
let mut bad: arch::__m256i = simd::splat_0().into();
let mut bad: arch::__m256i = simd::splat_0u8().into();
let mut i = 0;
while i < util::align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) {
let hex_digits = unsafe {
@ -611,43 +551,11 @@ fn decode_hex_bytes_unchecked(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bo
core::arch::asm!("vpor {bad}, {digits}, {bad}", bad = inout(ymm_reg) bad, digits = in(ymm_reg) hex_digits, options(pure, nomem, preserves_flags, nostack));
}
}
let hex_digits: Simd<u8, DIGIT_BATCH_SIZE> = hex_digits.into();
if_trace_simd! {
println!("hex_digits: {hex_digits:x?}");
}
let hex_digits: arch::__m256i = hex_digits.into();
let msb = simd::extract_lo_bytes(hex_digits);
let lsb = simd::extract_hi_bytes(hex_digits);
let msb: Simd<u8, WIDE_BATCH_SIZE> = msb.into();
let lsb: Simd<u8, WIDE_BATCH_SIZE> = lsb.into();
if_trace_simd! {
println!("msb: {msb:x?}");
println!("lsb: {lsb:x?}");
println!("| Packed | Msb | Lsb | |");
Simd::<u8, DIGIT_BATCH_SIZE>::from(hex_digits)
.to_array()
.chunks(2)
.zip(msb.to_array())
.zip(lsb.to_array())
.for_each(|((chunk, msb), lsb)| {
println!(
"| {chunk:02x?} | {msb:x?} | {lsb:x?} | {ok} |",
chunk = (chunk[0] as u16) << 4 | (chunk[1] as u16),
ok = if chunk[0] == msb && chunk[1] == lsb {
'✓'
} else {
'✗'
}
);
});
}
let buf = (msb << simd::splat_n::<WIDE_BATCH_SIZE>(4)) | lsb;
let buf = merge_hex_digits_into_bytes_inline!(hex_digits);
//let buf: arch::__m128i = unsafe { util::cast(buf) };
let buf: arch::__m128i = buf.into();
//let buf: arch::__m128i = buf.into();
unsafe {
// vmovaps xmm0, xmmword ptr [rsi]


@ -1,4 +1,4 @@
use core::simd::{LaneCount, Simd, SupportedLaneCount};
use core::simd::{LaneCount, Simd, SupportedLaneCount, SimdElement};
use crate::util::cast;
@ -19,15 +19,28 @@ pub use core::arch::x86 as arch;
#[cfg(target_arch = "x86_64")]
pub use core::arch::x86_64 as arch;
pub trait IsSimd {
type Lane;
const LANES: usize;
}
impl<T, const LANES: usize> IsSimd for Simd<T, LANES> where LaneCount<LANES>: SupportedLaneCount, T: SimdElement {
type Lane = T;
const LANES: usize = LANES;
}
macro_rules! specialized {
($(
$vis:vis fn $name:ident<$LANES:ident$(, $( $generics:tt )+)?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$vis:vis fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
$width:pat_param $( if $cfg:meta )? => $impl:expr
),+
$(,)?
}
)+) => {$(
#[allow(dead_code)]
#[inline(always)]
$vis fn $name<const $LANES: usize$(, $( $generics )+)?>($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
// abusing const generics to specialize without the unsoundness of real specialization!
@ -39,12 +52,35 @@ macro_rules! specialized {
}
}
)+};
($trait:ident for $ty:ty;
$(
fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
$width:pat_param $( if $cfg:meta )? => $impl:expr
),+
$(,)?
}
)+
) => {
impl<const $LANES: usize> $trait for $ty {$(
#[inline(always)]
fn $name$(<$( $generics )+>)?($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
// abusing const generics to specialize without the unsoundness of real specialization!
match $LANES {
$(
$( #[cfg( $cfg )] )?
$width => $impl
),+
}
}
)+}
};
}
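The "abusing const generics" comment refers to matching on the const LANES parameter: LANES is known at compile time, so only the selected arm survives codegen, and the cfg guards drop arms whose intrinsics don't exist on the target. A minimal standalone sketch of the same idea, assuming u8 lanes (hypothetical helper, not part of the diff):
fn register_for_width<const LANES: usize>() -> &'static str {
    match LANES {
        16 => "xmm (128-bit)",
        32 => "ymm (256-bit)",
        64 => "zmm (512-bit)",
        _ => "scalar fallback",
    }
}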
macro_rules! set1 {
($arch:ident, $vec:ident, $reg:ident, $n:ident) => {{
($arch:ident, $inst:ident, $vec:ident, $reg:ident, $n:ident) => {{
let out: core::arch::$arch::$vec;
core::arch::asm!("vpbroadcastb {}, {}", lateout($reg) out, in(xmm_reg) cast::<_, core::arch::$arch::__m128i>($n), options(pure, nomem, preserves_flags, nostack));
core::arch::asm!(concat!(stringify!($inst), " {}, {}"), lateout($reg) out, in(xmm_reg) cast::<_, core::arch::$arch::__m128i>($n), options(pure, nomem, preserves_flags, nostack));
out
}};
}
@ -54,10 +90,10 @@ specialized! {
pub fn splat_n<LANES>(n: u8) -> Simd<u8, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, __m128i, xmm_reg, n)) },
W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, __m128i, xmm_reg, n)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, __m256i, ymm_reg, n)) },
W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, __m256i, ymm_reg, n)) },
W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastb, __m128i, xmm_reg, n)) },
W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastb, __m128i, xmm_reg, n)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastb, __m256i, ymm_reg, n)) },
W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastb, __m256i, ymm_reg, n)) },
// these are *terrible*. They compile to a bunch of MOVs and SETs
W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) },
@ -71,7 +107,28 @@ specialized! {
_ => Simd::splat(n),
}
pub fn splat_0<LANES>() -> Simd<u8, LANES> where [
// TODO: special case https://www.felixcloutier.com/x86/vpbroadcastb:vpbroadcastw:vpbroadcastd:vpbroadcastq
pub fn splat_u16<LANES>(n: u16) -> Simd<u16, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastw, __m128i, xmm_reg, n)) },
W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastw, __m128i, xmm_reg, n)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastw, __m256i, ymm_reg, n)) },
W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastw, __m256i, ymm_reg, n)) },
// these are *terrible*. They compile to a bunch of MOVs and SETs
W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi16(n as i16)) },
W_128 if all(target_arch = "x86", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm_set1_epi16(n as i16)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm256_set1_epi16(n as i16)) },
W_256 if all(target_arch = "x86", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm256_set1_epi16(n as i16)) },
// I can't really test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi16(n as i16)) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi16(n as i16)) },
_ => Simd::splat(n),
}
pub fn splat_0u8<LANES>() -> Simd<u8, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
// these are fine; they just XOR the register with itself to zero it out.
@ -81,7 +138,20 @@ specialized! {
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
_ => splat_n(0),
_ => unsafe { crate::util::cast(splat_n(0)) },
}
pub fn splat_0u16<LANES>() -> Simd<u16, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
// these are fine; they just XOR the register with itself to zero it out.
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
_ => unsafe { crate::util::cast(splat_n(0)) },
}
}
@ -156,8 +226,9 @@ pub fn load_u64_m128(v: u64) -> arch::__m128i {
}
}
/// Merges the low halves of `a` and `b` into a single register like `ab`.
#[inline(always)]
pub fn merge_low_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
pub fn merge_lo_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
unsafe {
// xmm0 = xmm1[0],xmm0[0]
let out: _;
@ -217,6 +288,14 @@ pub trait SimdTestAnd {
fn test_and_non_zero(self, mask: Self) -> bool;
}
pub trait SimdBitwise {
/// Returns the bitwise OR of `self` and `rhs`.
fn or(self, rhs: Self) -> Self;
/// Returns the bitwise AND of `self` and `rhs`.
fn and(self, rhs: Self) -> Self;
}
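These two methods back the masking step in merge_hex_digits_into_bytes_inline! above, roughly (illustrative; `mask` stands in for the 0xf0f0 splat built inline there):
let msb2 = msb1.and(mask);   // vpand / _mm_and_si128
let b    = msb2.or(lsb);     // vpor  / _mm_or_si128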
#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m128i {
#[inline(always)]
@ -240,3 +319,63 @@ impl SimdTestAnd for arch::__m256i {
}
}
}
const USE_BITWISE_INTRINSICS: bool = true;
#[cfg(target_feature = "avx")]
impl SimdBitwise for arch::__m128i {
#[inline(always)]
fn or(self, rhs: Self) -> Self {
unsafe {
if USE_BITWISE_INTRINSICS {
arch::_mm_or_si128(self, rhs)
} else {
let out: _;
core::arch::asm!("vpor {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = out(xmm_reg) out, options(pure, nomem, preserves_flags, nostack));
out
}
}
}
#[inline(always)]
fn and(self, rhs: Self) -> Self {
unsafe {
if USE_BITWISE_INTRINSICS {
arch::_mm_and_si128(self, rhs)
} else {
let out: _;
core::arch::asm!("vpand {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = out(xmm_reg) out, options(pure, nomem, preserves_flags, nostack));
out
}
}
}
}
#[cfg(target_feature = "avx2")]
impl SimdBitwise for arch::__m256i {
#[inline(always)]
fn or(self, rhs: Self) -> Self {
unsafe {
if USE_BITWISE_INTRINSICS {
arch::_mm256_or_si256(self, rhs)
} else {
let out: _;
core::arch::asm!("vpor {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = out(ymm_reg) out, options(pure, nomem, preserves_flags, nostack));
out
}
}
}
#[inline(always)]
fn and(self, rhs: Self) -> Self {
unsafe {
if USE_BITWISE_INTRINSICS {
arch::_mm256_and_si256(self, rhs)
} else {
let out: _;
core::arch::asm!("vpand {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = out(ymm_reg) out, options(pure, nomem, preserves_flags, nostack));
out
}
}
}
}


@ -1,3 +1,22 @@
#[doc(hidden)]
#[macro_export]
macro_rules! __array_op {
(gen[$len:expr] |$i:pat_param| $val:expr) => {{
let mut out = std::mem::MaybeUninit::uninit_array();
let mut i = 0;
while i < $len {
out[i] = std::mem::MaybeUninit::new(match i { $i => $val });
i += 1;
}
unsafe { std::mem::MaybeUninit::array_assume_init(out) }
}};
(map[$len:expr, $src:expr] |$i:pat_param, $s:pat_param| $val:expr) => {{
$crate::util::array_op!(gen[$len] |i| match i { $i => match $src[i] { $s => $val } })
}};
}
pub use __array_op as array_op;
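Illustrative expansions of the two forms (the values follow from the closures; not part of the diff):
array_op!(gen[4] |i| i * 2)                       == [0, 2, 4, 6]
array_op!(map[3, [1u8, 2, 3]] |_, v| v as u32)    == [1u32, 2, 3]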
#[doc(hidden)]
#[macro_export]
macro_rules! __defer_impl {
@ -19,10 +38,10 @@ pub use __defer_impl as defer_impl;
#[inline(always)]
#[cold]
pub fn cold() {}
pub const fn cold() {}
#[inline]
pub fn likely(b: bool) -> bool {
pub const fn likely(b: bool) -> bool {
if !b {
cold()
}
@ -30,7 +49,7 @@ pub fn likely(b: bool) -> bool {
}
#[inline]
pub fn unlikely(b: bool) -> bool {
pub const fn unlikely(b: bool) -> bool {
if b {
cold()
}
@ -40,7 +59,7 @@ pub fn unlikely(b: bool) -> bool {
/// Like transmute, but implemented via a union so that we can use it in situations where
/// transmute's "safety" restrictions are too strict and uninformed (i.e. we can prove it is safe).
#[inline(always)]
pub unsafe fn cast<A, B>(a: A) -> B {
pub const unsafe fn cast<A, B>(a: A) -> B {
union Cast<A, B> {
a: core::mem::ManuallyDrop<A>,
b: core::mem::ManuallyDrop<B>,