From 41172e3adbde6de4727b1117a6e0cf631877cc99 Mon Sep 17 00:00:00 2001
From: Michael Pfaff
Date: Fri, 16 Dec 2022 00:00:12 -0500
Subject: [PATCH] Expand API, fix some bugs, optimize

---
 src/prelude.rs |   1 +
 src/simd.rs    | 221 +++++++++++++++++++++++++++++++++++++++++++------
 2 files changed, 197 insertions(+), 25 deletions(-)

diff --git a/src/prelude.rs b/src/prelude.rs
index ce6eabd..fd895ec 100644
--- a/src/prelude.rs
+++ b/src/prelude.rs
@@ -1,5 +1,6 @@
 pub use crate::simd::SimdBitwise;
 pub use crate::simd::SimdLoad;
+pub use crate::simd::SimdLoadArray;
 pub use crate::simd::SimdSplat;
 pub use crate::simd::SimdTestAnd;
 
diff --git a/src/simd.rs b/src/simd.rs
index b9b38b0..770932f 100644
--- a/src/simd.rs
+++ b/src/simd.rs
@@ -28,18 +28,6 @@ pub use core::arch::x86_64 as arch;
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 pub use arch::{__m128i, __m256i, __m512i};
 
-// use the maximum batch size that would be supported by AVX-512
-//pub(crate) const SIMD_WIDTH: usize = 512;
-pub const SIMD_WIDTH: usize = 256;
-
-/// The batch size used for the "wide" decoded hex bytes (any bit in the upper half indicates an error).
-pub const WIDE_BATCH_SIZE: usize = SIMD_WIDTH / 16;
-
-/// The batch size used for the hex digits.
-pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2;
-
-pub const GATHER_BATCH_SIZE: usize = DIGIT_BATCH_SIZE / 4;
-
 #[cfg(feature = "trace-simd")]
 #[macro_export]
 macro_rules! __if_trace_simd {
@@ -69,7 +57,7 @@ where
     const LANES: usize = LANES;
 }
 
-macro_rules! specialized {
+/*macro_rules! specialized {
     ($(
         $vis:vis fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $args:tt )*) -> $rt:ty $(where [ $( $where:tt )* ])? {
             $(
@@ -133,7 +121,32 @@ macro_rules! specialized {
             }
         )+}
     };
-    }
+    }*/
+
+macro_rules! specialized {
+    ($ty:ty =>
+        $(
+            $vis:vis fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $args:tt )*) -> $rt:ty $(where [ $( $where:tt )* ])? {
+                $(
+                    $width:pat_param $( if $cfg:meta )? => $impl:expr
+                ),+
+                $(,)?
+            }
+        )+
+    ) => {$(
+        #[allow(dead_code)]
+        #[inline(always)]
+        $vis fn $name<const $LANES: usize$(, $( $generics )+)?>($( $args )*) -> $rt $( where $( $where )* )? {
+            // abusing const generics to specialize without the unsoundness of real specialization!
+            match $LANES * core::mem::size_of::<$ty>() {
+                $(
+                    $( #[cfg( $cfg )] )?
+                    $width => $impl
+                ),+
+            }
+        }
+    )+};
+}
 
 macro_rules! set1_short {
     ($inst:ident, $vec:ident, $reg:ident, $n:ident: $n_ty:ty) => {{
@@ -194,7 +207,7 @@ pub trait SimdSplat {
     fn splat_zero<const LANES: usize>() -> Self::Output<LANES> where LaneCount<LANES>: SupportedLaneCount;
 }
 
-pub trait SimdLoad: Sized {
+pub trait SimdLoad: Sized + SimdElement {
     fn load_128(self) -> arch::__m128i;
 
     #[inline(always)]
@@ -208,6 +221,33 @@ pub trait SimdLoad: Sized {
     }
 }
 
+pub trait SimdLoadArray: Sized + SimdElement {
+    fn load_array<const LANES: usize>(array: [Self; LANES]) -> Simd<Self, LANES> where LaneCount<LANES>: SupportedLaneCount;
+}
+
+impl SimdLoadArray for u64 {
+    specialized! { Self =>
+        fn load_array<LANES>(array: [Self; LANES]) -> Simd<Self, LANES> where [LaneCount<LANES>: SupportedLaneCount] {
+            W_128 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx") => unsafe {
+                let mut out: __m128i;
+                core::arch::asm!("vmovq {}, {:r}", lateout(xmm_reg) out, in(reg) array[0]);
+                core::arch::asm!("vpinsrq {out}, {out}, {b:r}, 0x1", out = inout(xmm_reg) out, b = in(reg) array[1]);
+                cast(out)
+            },
+            W_256 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx") => unsafe {
+                let mut out_a: __m128i;
+                core::arch::asm!("vmovq {}, {:r}", lateout(xmm_reg) out_a, in(reg) array[0]);
+                core::arch::asm!("vpinsrq {out}, {out}, {b:r}, 0x1", out = inout(xmm_reg) out_a, b = in(reg) array[1]);
+                let mut out_b: __m128i;
+                core::arch::asm!("vmovq {}, {:r}", lateout(xmm_reg) out_b, in(reg) array[2]);
+                core::arch::asm!("vpinsrq {out}, {out}, {b:r}, 0x1", out = inout(xmm_reg) out_b, b = in(reg) array[3]);
+                cast(merge_m128_m256(out_a, out_b))
+            },
+            _ => Simd::from_array(array),
+        }
+    }
+}
+
 macro_rules! impl_load {
     (($in:ident, $($in_fmt:ident)?) $in_v:ident) => {{
         unsafe {
@@ -254,7 +294,7 @@ macro_rules! impl_ops {
         impl const SimdSplat for $ty {
             type Output<const LANES: usize> = Simd<$ty, LANES> where LaneCount<LANES>: SupportedLaneCount;
 
-            specialized! {
+            specialized! { Self =>
                 fn splat<LANES>(self) -> Self::Output<LANES> where [LaneCount<LANES>: SupportedLaneCount] {
                     W_128 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx2") => unsafe { cast(set1_short!($broadcast, __m128i, xmm_reg, self: $ty)) },
                     W_256 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx2") => unsafe { cast(set1_long!($broadcast, __m256i, ymm_reg, self: $ty)) },
@@ -349,28 +389,28 @@ impl_ops! {
         reg: reg_byte,
         reg_fmt: ,
         broadcast: vpbroadcastb,
-        set1: [_mm_set1_epi8, _mm256_set1_epi8, _mm512_set1_epi8]
+        set1: [_mm_set1_epi8, _mm256_set1_epi8, _mm512_set1_epi8],
     }
 
     u16 {
         reg: reg,
         reg_fmt: e,
         broadcast: vpbroadcastw,
-        set1: [_mm_set1_epi16, _mm256_set1_epi16, _mm512_set1_epi16]
+        set1: [_mm_set1_epi16, _mm256_set1_epi16, _mm512_set1_epi16],
     }
 
     u32 {
         reg: reg,
         reg_fmt: e,
         broadcast: vpbroadcastd,
-        set1: [_mm_set1_epi32, _mm256_set1_epi32, _mm512_set1_epi32]
+        set1: [_mm_set1_epi32, _mm256_set1_epi32, _mm512_set1_epi32],
    }
 
     u64 {
         reg: reg,
         reg_fmt: r,
         broadcast: vpbroadcastq,
-        set1: [_mm_set1_epi64, _mm256_set1_epi64, _mm512_set1_epi64]
+        set1: [_mm_set1_epi64, _mm256_set1_epi64, _mm512_set1_epi64],
     }
 }
 
@@ -436,17 +476,51 @@ macro_rules! __swizzle {
 pub use __swizzle as swizzle;
 pub use __swizzle_indices as swizzle_indices;
 
-/// Merges the low halves of `a` and `b` into a single register like `ab`.
 #[inline(always)]
-pub fn merge_lo_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
+pub fn interleave_lo_bytes_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
+    unsafe {
+        let out: _;
+        core::arch::asm!("vpunpcklbw {}, {}, {}", lateout(xmm_reg) out, in(xmm_reg) a, in(xmm_reg) b, options(pure, nomem, preserves_flags, nostack));
+        out
+    }
+}
+
+#[inline(always)]
+pub fn interleave_lo_quads_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
     unsafe {
-        // xmm0 = xmm1[0],xmm0[0]
         let out: _;
         core::arch::asm!("vpunpcklqdq {}, {}, {}", lateout(xmm_reg) out, in(xmm_reg) a, in(xmm_reg) b, options(pure, nomem, preserves_flags, nostack));
         out
     }
 }
 
+#[inline(always)]
+pub fn interleave_lo_bytes_m256(a: arch::__m256i, b: arch::__m256i) -> arch::__m256i {
+    unsafe {
+        let out: _;
+        core::arch::asm!("vpunpcklbw {}, {}, {}", lateout(ymm_reg) out, in(ymm_reg) a, in(ymm_reg) b, options(pure, nomem, preserves_flags, nostack));
+        out
+    }
+}
+
+#[inline(always)]
+pub fn interleave_lo_quads_m256(a: arch::__m256i, b: arch::__m256i) -> arch::__m256i {
+    unsafe {
+        let out: _;
+        core::arch::asm!("vpunpcklqdq {}, {}, {}", lateout(ymm_reg) out, in(ymm_reg) a, in(ymm_reg) b, options(pure, nomem, preserves_flags, nostack));
+        out
+    }
+}
+
+/// Merges the low halves of `a` and `b` into a single register like `ab`.
+#[inline(always)]
+pub fn merge_lo_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
+    // this works because vpunpcklqdq interleaves the low order **quadwords** (aka the entire lower
+    // half)
+    // xmm0 = xmm1[0],xmm0[0]
+    interleave_lo_quads_m128(a, b)
+}
+
 /// The args are in little endian order (first arg is lowest order)
 #[inline(always)]
 pub fn merge_m128_m256(a: arch::__m128i, b: arch::__m128i) -> arch::__m256i {
@@ -484,12 +558,12 @@ macro_rules! extract_lohi_bytes {
 }
 
 #[inline(always)]
-pub fn extract_lo_bytes(v: arch::__m256i) -> arch::__m128i {
+pub fn extract_lo_bytes(v: __m256i) -> __m128i {
     extract_lohi_bytes!(([0x00ffu16; 8], vpand, vpackuswb), v)
 }
 
 #[inline(always)]
-pub fn extract_hi_bytes(v: arch::__m256i) -> arch::__m128i {
+pub fn extract_hi_bytes(v: __m256i) -> __m128i {
     extract_lohi_bytes!(
         (
             [0x1u8, 0x3, 0x5, 0x7, 0x9, 0xb, 0xd, 0xf, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0],
@@ -518,6 +592,44 @@ pub trait SimdWiden {
     fn widen(self) -> Self;
 }
 
+#[inline(always)]
+pub fn interleave_m64(a: __m128i, b: __m128i) -> __m128i {
+    interleave_lo_bytes_m128(a, b)
+}
+
+#[inline(always)]
+pub fn interleave_m128(a: __m128i, b: __m128i) -> __m256i {
+    const INTERLEAVE_A: Simd<u8, 32> = Simd::from_array(util::array_op!(gen[32] |i| {
+        if i & 1 == 0 {
+            (i as u8) >> 1
+        } else {
+            0xff
+        }
+    }));
+    const INTERLEAVE_B: Simd<u8, 32> = Simd::from_array(util::array_op!(gen[32] |i| {
+        if i & 1 == 0 {
+            0xff
+        } else {
+            (i as u8) >> 1
+        }
+    }));
+    unsafe {
+        //let zero: __m128i = u8::splat_zero::<16>().into();
+        let spaced_a = arch::_mm256_cvtepu8_epi16(a);
+        //let spaced_a = interleave_lo_bytes_m256(a, zero);
+        //let spaced_b = interleave_lo_bytes_m256(zero, b);
+        //let spaced_b = arch::_mm256_cvtepu8_epi16(b);
+        //let a = merge_m128_m256(a, a);
+        let b = merge_m128_m256(b, b);
+        //let spaced_a = arch::_mm256_shuffle_epi8(a, INTERLEAVE_A.into());
+        let spaced_b = arch::_mm256_shuffle_epi8(b, INTERLEAVE_B.into());
+        //let spaced_b: arch::__m256i = shl!(64, 8, (ymm_reg) spaced_b);
+        //interleave_lo_bytes_m256(a, b)
+        //interleave_lo_bytes_m256(spaced_a, spaced_b)
+        spaced_a.or(spaced_b)
+    }
+}
+
 #[cfg(target_feature = "avx")]
 impl SimdTestAnd for arch::__m128i {
     #[inline(always)]
@@ -714,6 +826,27 @@ macro_rules! __simd__shr {
 
 pub use __simd__shr as shr;
 
+#[doc(hidden)]
+#[macro_export]
+macro_rules! __simd__shl {
+    ($inst:ident, $n:literal, ($in_reg:ident) $in:expr) => {{
+        let out: _;
+        core::arch::asm!(concat!(stringify!($inst), " {dst}, {src}, ", $n), src = in($in_reg) $in, dst = lateout($in_reg) out);
+        out
+    }};
+    (16, $n:literal, ($in_reg:ident) $in:expr) => {
+        $crate::simd::shl!(vpsllw, $n, ($in_reg) $in)
+    };
+    (32, $n:literal, ($in_reg:ident) $in:expr) => {
+        $crate::simd::shl!(vpslld, $n, ($in_reg) $in)
+    };
+    (64, $n:literal, ($in_reg:ident) $in:expr) => {
+        $crate::simd::shl!(vpsllq, $n, ($in_reg) $in)
+    };
+}
+
+pub use __simd__shl as shl;
+
 /*impl SimdWiden for arch::__m128i {
     #[inline(always)]
     fn widen(self) -> Self {
@@ -743,6 +876,7 @@ impl SimdWiden for arch::__m256i {
 
 #[cfg(test)]
 mod test {
+    use crate::prelude::*;
     use super::*;
 
     #[test]
@@ -789,6 +923,43 @@ mod test {
         assert!(EQ_CONST, "constant evaluated splat_zero did not produce the expected result");
     }
 
+    #[test]
+    fn test_interleave_out_128() {
+        const EXPECTED: [u8; 16] = array_op!(gen[16] |i| i as u8);
+        const A: [u8; 16] = array_op!(gen[16] |i| (i as u8) << 1);
+        const B: [u8; 16] = array_op!(gen[16] |i| ((i as u8) << 1) + 1);
+
+        let actual = interleave_lo_bytes_m128(Simd::from_array(A).into(), Simd::from_array(B).into());
+        assert_eq!(Simd::from(actual), Simd::from_array(EXPECTED));
+    }
+
+    #[test]
+    fn test_interleave_out_256() {
+        const EXPECTED: [u8; 32] = array_op!(gen[32] |i| i as u8);
+        const A: [u8; 16] = array_op!(gen[16] |i| (i as u8) << 1);
+        const B: [u8; 16] = array_op!(gen[16] |i| ((i as u8) << 1) + 1);
+        const A1: [u8; 32] = array_op!(gen[32] |i| (i as u8) << 1);
+        const B1: [u8; 32] = array_op!(gen[32] |i| ((i as u8) << 1) + 1);
+
+        let a = Simd::from_array(A).into();
+        let b = Simd::from_array(B).into();
+
+        //let a = merge_m128_m256(a, a);
+        //let b = merge_m128_m256(b, b);
+
+        let actual = interleave_m128(a, b);
+        assert_eq!(Simd::from(actual), Simd::from_array(EXPECTED));
+    }
+
+    #[test]
+    fn test_load_array() {
+        const ARRAY_U64_128: [u64; 2] = [46374, 5748782187];
+        const ARRAY_U64_256: [u64; 4] = [46374, 5748782187, 38548839, 91848923];
+
+        assert_eq!(u64::load_array(ARRAY_U64_128), Simd::from_array(ARRAY_U64_128));
+        assert_eq!(u64::load_array(ARRAY_U64_256), Simd::from_array(ARRAY_U64_256));
+    }
+
     /*#[test]
     fn test_widen() {
         const SAMPLE_IN: [u8; 32] = [