Expand API, fix some bugs, optimize

This commit is contained in:
Michael Pfaff 2022-12-16 00:00:12 -05:00
parent 32d9111232
commit 41172e3adb
Signed by: michael
GPG Key ID: CF402C4A012AA9D4
2 changed files with 197 additions and 25 deletions

View File

@ -1,5 +1,6 @@
pub use crate::simd::SimdBitwise;
pub use crate::simd::SimdLoad;
pub use crate::simd::SimdLoadArray;
pub use crate::simd::SimdSplat;
pub use crate::simd::SimdTestAnd;

View File

@ -28,18 +28,6 @@ pub use core::arch::x86_64 as arch;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub use arch::{__m128i, __m256i, __m512i};
// use the maximum batch size that would be supported by AVX-512
//pub(crate) const SIMD_WIDTH: usize = 512;
pub const SIMD_WIDTH: usize = 256;
/// The batch size used for the "wide" decoded hex bytes (any bit in the upper half indicates an error).
pub const WIDE_BATCH_SIZE: usize = SIMD_WIDTH / 16;
/// The batch size used for the hex digits.
pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2;
pub const GATHER_BATCH_SIZE: usize = DIGIT_BATCH_SIZE / 4;
#[cfg(feature = "trace-simd")]
#[macro_export]
macro_rules! __if_trace_simd {
@ -69,7 +57,7 @@ where
const LANES: usize = LANES;
}
macro_rules! specialized {
/*macro_rules! specialized {
($(
$vis:vis fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $args:tt )*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
@ -133,7 +121,32 @@ macro_rules! specialized {
}
)+}
};
}
}*/
macro_rules! specialized {
($ty:ty =>
$(
$vis:vis fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $args:tt )*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
$width:pat_param $( if $cfg:meta )? => $impl:expr
),+
$(,)?
}
)+
) => {$(
#[allow(dead_code)]
#[inline(always)]
$vis fn $name<const $LANES: usize$(, $( $generics )+)?>($( $args )*) -> $rt $( where $( $where )* )? {
// abusing const generics to specialize without the unsoundness of real specialization!
match $LANES * core::mem::size_of::<$ty>() {
$(
$( #[cfg( $cfg )] )?
$width => $impl
),+
}
}
)+};
}
macro_rules! set1_short {
($inst:ident, $vec:ident, $reg:ident, $n:ident: $n_ty:ty) => {{
@ -194,7 +207,7 @@ pub trait SimdSplat {
fn splat_zero<const LANES: usize>() -> Self::Output<LANES> where LaneCount<LANES>: SupportedLaneCount;
}
pub trait SimdLoad: Sized {
pub trait SimdLoad: Sized + SimdElement {
fn load_128(self) -> arch::__m128i;
#[inline(always)]
@ -208,6 +221,33 @@ pub trait SimdLoad: Sized {
}
}
pub trait SimdLoadArray: Sized + SimdElement {
fn load_array<const LANES: usize>(array: [Self; LANES]) -> Simd<Self, LANES> where LaneCount<LANES>: SupportedLaneCount;
}
impl SimdLoadArray for u64 {
specialized! { Self =>
fn load_array<LANES>(array: [Self; LANES]) -> Simd<Self, LANES> where [LaneCount<LANES>: SupportedLaneCount] {
W_128 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx") => unsafe {
let mut out: __m128i;
core::arch::asm!("vmovq {}, {:r}", lateout(xmm_reg) out, in(reg) array[0]);
core::arch::asm!("vpinsrq {out}, {out}, {b:r}, 0x1", out = inout(xmm_reg) out, b = in(reg) array[1]);
cast(out)
},
W_256 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx") => unsafe {
let mut out_a: __m128i;
core::arch::asm!("vmovq {}, {:r}", lateout(xmm_reg) out_a, in(reg) array[0]);
core::arch::asm!("vpinsrq {out}, {out}, {b:r}, 0x1", out = inout(xmm_reg) out_a, b = in(reg) array[1]);
let mut out_b: __m128i;
core::arch::asm!("vmovq {}, {:r}", lateout(xmm_reg) out_b, in(reg) array[2]);
core::arch::asm!("vpinsrq {out}, {out}, {b:r}, 0x1", out = inout(xmm_reg) out_b, b = in(reg) array[3]);
cast(merge_m128_m256(out_a, out_b))
},
_ => Simd::from_array(array),
}
}
}
macro_rules! impl_load {
(($in:ident, $($in_fmt:ident)?) $in_v:ident) => {{
unsafe {
@ -254,7 +294,7 @@ macro_rules! impl_ops {
impl const SimdSplat for $ty {
type Output<const LANES: usize> = Simd<$ty, LANES> where LaneCount<LANES>: SupportedLaneCount;
specialized! {
specialized! { Self =>
fn splat<LANES>(self) -> Self::Output<LANES> where [LaneCount<LANES>: SupportedLaneCount] {
W_128 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx2") => unsafe { cast(set1_short!($broadcast, __m128i, xmm_reg, self: $ty)) },
W_256 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx2") => unsafe { cast(set1_long!($broadcast, __m256i, ymm_reg, self: $ty)) },
@ -349,28 +389,28 @@ impl_ops! {
reg: reg_byte,
reg_fmt: ,
broadcast: vpbroadcastb,
set1: [_mm_set1_epi8, _mm256_set1_epi8, _mm512_set1_epi8]
set1: [_mm_set1_epi8, _mm256_set1_epi8, _mm512_set1_epi8],
}
u16 {
reg: reg,
reg_fmt: e,
broadcast: vpbroadcastw,
set1: [_mm_set1_epi16, _mm256_set1_epi16, _mm512_set1_epi16]
set1: [_mm_set1_epi16, _mm256_set1_epi16, _mm512_set1_epi16],
}
u32 {
reg: reg,
reg_fmt: e,
broadcast: vpbroadcastd,
set1: [_mm_set1_epi32, _mm256_set1_epi32, _mm512_set1_epi32]
set1: [_mm_set1_epi32, _mm256_set1_epi32, _mm512_set1_epi32],
}
u64 {
reg: reg,
reg_fmt: r,
broadcast: vpbroadcastq,
set1: [_mm_set1_epi64, _mm256_set1_epi64, _mm512_set1_epi64]
set1: [_mm_set1_epi64, _mm256_set1_epi64, _mm512_set1_epi64],
}
}
@ -436,17 +476,51 @@ macro_rules! __swizzle {
pub use __swizzle as swizzle;
pub use __swizzle_indices as swizzle_indices;
/// Merges the low halves of `a` and `b` into a single register like `ab`.
#[inline(always)]
pub fn merge_lo_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
pub fn interleave_lo_bytes_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
unsafe {
let out: _;
core::arch::asm!("vpunpcklbw {}, {}, {}", lateout(xmm_reg) out, in(xmm_reg) a, in(xmm_reg) b, options(pure, nomem, preserves_flags, nostack));
out
}
}
#[inline(always)]
pub fn interleave_lo_quads_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
unsafe {
// xmm0 = xmm1[0],xmm0[0]
let out: _;
core::arch::asm!("vpunpcklqdq {}, {}, {}", lateout(xmm_reg) out, in(xmm_reg) a, in(xmm_reg) b, options(pure, nomem, preserves_flags, nostack));
out
}
}
#[inline(always)]
pub fn interleave_lo_bytes_m256(a: arch::__m256i, b: arch::__m256i) -> arch::__m256i {
unsafe {
let out: _;
core::arch::asm!("vpunpcklbw {}, {}, {}", lateout(ymm_reg) out, in(ymm_reg) a, in(ymm_reg) b, options(pure, nomem, preserves_flags, nostack));
out
}
}
#[inline(always)]
pub fn interleave_lo_quads_m256(a: arch::__m256i, b: arch::__m256i) -> arch::__m256i {
unsafe {
let out: _;
core::arch::asm!("vpunpcklqdq {}, {}, {}", lateout(ymm_reg) out, in(ymm_reg) a, in(ymm_reg) b, options(pure, nomem, preserves_flags, nostack));
out
}
}
/// Merges the low halves of `a` and `b` into a single register like `ab`.
#[inline(always)]
pub fn merge_lo_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
// this works because vpunpcklqdq interleaves the low order **quadwords** (aka the entire lower
// half)
// xmm0 = xmm1[0],xmm0[0]
interleave_lo_quads_m128(a, b)
}
/// The args are in little endian order (first arg is lowest order)
#[inline(always)]
pub fn merge_m128_m256(a: arch::__m128i, b: arch::__m128i) -> arch::__m256i {
@ -484,12 +558,12 @@ macro_rules! extract_lohi_bytes {
}
#[inline(always)]
pub fn extract_lo_bytes(v: arch::__m256i) -> arch::__m128i {
pub fn extract_lo_bytes(v: __m256i) -> __m128i {
extract_lohi_bytes!(([0x00ffu16; 8], vpand, vpackuswb), v)
}
#[inline(always)]
pub fn extract_hi_bytes(v: arch::__m256i) -> arch::__m128i {
pub fn extract_hi_bytes(v: __m256i) -> __m128i {
extract_lohi_bytes!(
(
[0x1u8, 0x3, 0x5, 0x7, 0x9, 0xb, 0xd, 0xf, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0],
@ -518,6 +592,44 @@ pub trait SimdWiden {
fn widen<const N: usize, const O: usize, const BE: bool>(self) -> Self;
}
#[inline(always)]
pub fn interleave_m64(a: __m128i, b: __m128i) -> __m128i {
interleave_lo_bytes_m128(a, b)
}
#[inline(always)]
pub fn interleave_m128(a: __m128i, b: __m128i) -> __m256i {
const INTERLEAVE_A: Simd<u8, 32> = Simd::from_array(util::array_op!(gen[32] |i| {
if i & 1 == 0 {
(i as u8) >> 1
} else {
0xff
}
}));
const INTERLEAVE_B: Simd<u8, 32> = Simd::from_array(util::array_op!(gen[32] |i| {
if i & 1 == 0 {
0xff
} else {
(i as u8) >> 1
}
}));
unsafe {
//let zero: __m128i = u8::splat_zero::<16>().into();
let spaced_a = arch::_mm256_cvtepu8_epi16(a);
//let spaced_a = interleave_lo_bytes_m256(a, zero);
//let spaced_b = interleave_lo_bytes_m256(zero, b);
//let spaced_b = arch::_mm256_cvtepu8_epi16(b);
//let a = merge_m128_m256(a, a);
let b = merge_m128_m256(b, b);
//let spaced_a = arch::_mm256_shuffle_epi8(a, INTERLEAVE_A.into());
let spaced_b = arch::_mm256_shuffle_epi8(b, INTERLEAVE_B.into());
//let spaced_b: arch::__m256i = shl!(64, 8, (ymm_reg) spaced_b);
//interleave_lo_bytes_m256(a, b)
//interleave_lo_bytes_m256(spaced_a, spaced_b)
spaced_a.or(spaced_b)
}
}
#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m128i {
#[inline(always)]
@ -714,6 +826,27 @@ macro_rules! __simd__shr {
pub use __simd__shr as shr;
#[doc(hidden)]
#[macro_export]
macro_rules! __simd__shl {
($inst:ident, $n:literal, ($in_reg:ident) $in:expr) => {{
let out: _;
core::arch::asm!(concat!(stringify!($inst), " {dst}, {src}, ", $n), src = in($in_reg) $in, dst = lateout($in_reg) out);
out
}};
(16, $n:literal, ($in_reg:ident) $in:expr) => {
$crate::simd::shl!(vpsllw, $n, ($in_reg) $in)
};
(32, $n:literal, ($in_reg:ident) $in:expr) => {
$crate::simd::shl!(vpslld, $n, ($in_reg) $in)
};
(64, $n:literal, ($in_reg:ident) $in:expr) => {
$crate::simd::shl!(vpsllq, $n, ($in_reg) $in)
};
}
pub use __simd__shl as shl;
/*impl SimdWiden for arch::__m128i {
#[inline(always)]
fn widen<const N: usize, const O: usize, const BE: bool>(self) -> Self {
@ -743,6 +876,7 @@ impl SimdWiden for arch::__m256i {
#[cfg(test)]
mod test {
use crate::prelude::*;
use super::*;
#[test]
@ -789,6 +923,43 @@ mod test {
assert!(EQ_CONST, "constant evaluated splat_zero did not produce the expected result");
}
#[test]
fn test_interleave_out_128() {
const EXPECTED: [u8; 16] = array_op!(gen[16] |i| i as u8);
const A: [u8; 16] = array_op!(gen[16] |i| (i as u8) << 1);
const B: [u8; 16] = array_op!(gen[16] |i| ((i as u8) << 1) + 1);
let actual = interleave_lo_bytes_m128(Simd::from_array(A).into(), Simd::from_array(B).into());
assert_eq!(Simd::from(actual), Simd::from_array(EXPECTED));
}
#[test]
fn test_interleave_out_256() {
const EXPECTED: [u8; 32] = array_op!(gen[32] |i| i as u8);
const A: [u8; 16] = array_op!(gen[16] |i| (i as u8) << 1);
const B: [u8; 16] = array_op!(gen[16] |i| ((i as u8) << 1) + 1);
const A1: [u8; 32] = array_op!(gen[32] |i| (i as u8) << 1);
const B1: [u8; 32] = array_op!(gen[32] |i| ((i as u8) << 1) + 1);
let a = Simd::from_array(A).into();
let b = Simd::from_array(B).into();
//let a = merge_m128_m256(a, a);
//let b = merge_m128_m256(b, b);
let actual = interleave_m128(a, b);
assert_eq!(Simd::from(actual), Simd::from_array(EXPECTED));
}
#[test]
fn test_load_array() {
const ARRAY_U64_128: [u64; 2] = [46374, 5748782187];
const ARRAY_U64_256: [u64; 4] = [46374, 5748782187, 38548839, 91848923];
assert_eq!(u64::load_array(ARRAY_U64_128), Simd::from_array(ARRAY_U64_128));
assert_eq!(u64::load_array(ARRAY_U64_256), Simd::from_array(ARRAY_U64_256));
}
/*#[test]
fn test_widen() {
const SAMPLE_IN: [u8; 32] = [