use core::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount};

use crate::util::cast;

const W_128: usize = 128 / 8;
const W_256: usize = 256 / 8;
const W_512: usize = 512 / 8;

#[cfg(target_arch = "aarch64")]
pub use core::arch::aarch64 as arch;
#[cfg(target_arch = "arm")]
pub use core::arch::arm as arch;
#[cfg(target_arch = "wasm32")]
pub use core::arch::wasm32 as arch;
#[cfg(target_arch = "wasm64")]
pub use core::arch::wasm64 as arch;
#[cfg(target_arch = "x86")]
pub use core::arch::x86 as arch;
#[cfg(target_arch = "x86_64")]
pub use core::arch::x86_64 as arch;

// use the maximum batch size that would be supported by AVX-512
//pub(crate) const SIMD_WIDTH: usize = 512;
pub const SIMD_WIDTH: usize = 256;

#[macro_export]
macro_rules! __if_trace_simd {
    ($( $tt:tt )*) => {
        // disabled
        //{ $( $tt )* }
    };
}

pub use __if_trace_simd as if_trace_simd;

pub trait IsSimd {
    type Lane;

    const LANES: usize;
}

impl<T, const LANES: usize> IsSimd for Simd<T, LANES>
where
    LaneCount<LANES>: SupportedLaneCount,
    T: SimdElement,
{
    type Lane = T;

    const LANES: usize = LANES;
}

macro_rules! specialized {
    ($(
        $vis:vis fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $argn:ident: $argt:ty ),*) -> $rt:ty
        $( where [ $( $where:tt )* ] )?
        {
            $( $width:pat_param $( if $cfg:meta )? => $impl:expr ),+
            $(,)?
        }
    )+) => {$(
        #[allow(dead_code)]
        #[inline(always)]
        $vis fn $name<const $LANES: usize$(, $( $generics )+)?>($( $argn: $argt ),*) -> $rt
        $( where $( $where )* )?
        {
            // abusing const generics to specialize without the unsoundness of real specialization!
            match $LANES {
                $(
                    $( #[cfg( $cfg )] )?
                    $width => $impl
                ),+
            }
        }
    )+};
    ($trait:ident for $ty:ty;
        $(
            fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $argn:ident: $argt:ty ),*) -> $rt:ty
            $( where [ $( $where:tt )* ] )?
            {
                $( $width:pat_param $( if $cfg:meta )? => $impl:expr ),+
                $(,)?
            }
        )+
    ) => {
        impl $trait for $ty {$(
            #[inline(always)]
            fn $name$(<$( $generics )+>)?($( $argn: $argt ),*) -> $rt
            $( where $( $where )* )?
            {
                // abusing const generics to specialize without the unsoundness of real specialization!
                match $LANES {
                    $(
                        $( #[cfg( $cfg )] )?
                        $width => $impl
                    ),+
                }
            }
        )+}
    };
}

macro_rules! set1 {
    ($arch:ident, $inst:ident, $vec:ident, $reg:ident, $n:ident) => {{
        let out: core::arch::$arch::$vec;
        core::arch::asm!(
            concat!(stringify!($inst), " {}, {}"),
            lateout($reg) out,
            in(xmm_reg) cast::<_, core::arch::$arch::__m128i>($n),
            options(pure, nomem, preserves_flags, nostack)
        );
        out
    }};
}
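// The trick `specialized!` relies on is matching a const generic against the
// width constants, with `#[cfg]` pruning the arms the target cannot compile.
// A minimal standalone sketch of the same pattern (hypothetical, for
// illustration only; `width_name` is not part of this module):
//
//     fn width_name<const LANES: usize>() -> &'static str {
//         match LANES {
//             W_128 => "128-bit",
//             W_256 => "256-bit",
//             W_512 => "512-bit",
//             _ => "unsupported",
//         }
//     }
//
// After monomorphization (e.g. `width_name::<W_256>()`) the match collapses to
// a single arm, so each function below costs no more than the one intrinsic
// selected for its width and target features.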
specialized! {
    // TODO: special-case https://www.felixcloutier.com/x86/vpbroadcastb:vpbroadcastw:vpbroadcastd:vpbroadcastq
    pub fn splat_n<LANES>(n: u8) -> Simd<u8, LANES> where [
        LaneCount<LANES>: SupportedLaneCount,
    ] {
        W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastb, __m128i, xmm_reg, n)) },
        W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastb, __m128i, xmm_reg, n)) },
        W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastb, __m256i, ymm_reg, n)) },
        W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastb, __m256i, ymm_reg, n)) },

        // These are *terrible*: they compile to a bunch of MOVs and SETs.
        W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) },
        W_128 if all(target_arch = "x86", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm_set1_epi8(n as i8)) },
        W_256 if all(target_arch = "x86_64", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm256_set1_epi8(n as i8)) },
        W_256 if all(target_arch = "x86", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm256_set1_epi8(n as i8)) },

        // I can't really test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
        W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi8(n as i8)) },
        W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi8(n as i8)) },

        _ => Simd::splat(n),
    }

    // TODO: special-case https://www.felixcloutier.com/x86/vpbroadcastb:vpbroadcastw:vpbroadcastd:vpbroadcastq
    pub fn splat_u16<LANES>(n: u16) -> Simd<u16, LANES> where [
        LaneCount<LANES>: SupportedLaneCount,
    ] {
        W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastw, __m128i, xmm_reg, n)) },
        W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastw, __m128i, xmm_reg, n)) },
        W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastw, __m256i, ymm_reg, n)) },
        W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastw, __m256i, ymm_reg, n)) },

        // These are *terrible*: they compile to a bunch of MOVs and SETs.
        W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi16(n as i16)) },
        W_128 if all(target_arch = "x86", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm_set1_epi16(n as i16)) },
        W_256 if all(target_arch = "x86_64", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm256_set1_epi16(n as i16)) },
        W_256 if all(target_arch = "x86", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm256_set1_epi16(n as i16)) },

        // I can't really test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
        W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi16(n as i16)) },
        W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi16(n as i16)) },

        _ => Simd::splat(n),
    }

    pub fn splat_0u8<LANES>() -> Simd<u8, LANES> where [
        LaneCount<LANES>: SupportedLaneCount,
    ] {
        // These are fine: the setzero intrinsics XOR the register with itself to zero it.
        W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
        W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
        W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
        W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
        W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
        W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
        _ => unsafe { crate::util::cast(splat_n(0)) },
    }

    pub fn splat_0u16<LANES>() -> Simd<u16, LANES> where [
        LaneCount<LANES>: SupportedLaneCount,
    ] {
        // These are fine: the setzero intrinsics XOR the register with itself to zero it.
        W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
        W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
        W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
        W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
        W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
        W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
        _ => unsafe { crate::util::cast(splat_n(0)) },
    }
}
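// Hypothetical sanity check (not exhaustive; assumes the test build's target
// features select one of the arms above, or fall through to the portable
// path): whichever arm is chosen, the splat helpers must agree with
// `Simd::splat`.
#[cfg(test)]
mod splat_tests {
    use super::*;

    #[test]
    fn splat_agrees_with_portable() {
        // 32 u8 lanes = 256 bits, i.e. the `W_256` specialization when AVX2 is enabled.
        assert_eq!(splat_n::<W_256>(0xa5), Simd::<u8, W_256>::splat(0xa5));
        assert_eq!(splat_0u8::<W_256>(), Simd::<u8, W_256>::splat(0));
    }
}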
/// Defines the indices used by [`swizzle`].
#[macro_export]
macro_rules! __swizzle_indices {
    ($name:ident = [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => {
        core::arch::global_asm!(
            concat!(".", stringify!($name), ":")
            $( , concat!("\n .byte ", stringify!($index)) )+
            $( $( , $crate::util::subst!([$padding], ["\n .zero 1"]) )+ )?
        );
    };
}

#[macro_export]
macro_rules! __swizzle {
    /*(xmm_reg, $src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => {
        $crate::simd::swizzle!(@ xmm_reg, $src, $dest, (xmmword) [$( $index ),+] $( , [$( $padding )+] )?)
    };
    (ymm_reg, $src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => {
        $crate::simd::swizzle!(@ ymm_reg, $src, $dest, (ymmword) [$( $index ),+] $( , [$( $padding )+] )?)
    };
    (zmm_reg, $src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => {
        $crate::simd::swizzle!(@ zmm_reg, $src, $dest, (zmmword) [$( $index ),+] $( , [$( $padding )+] )?)
    };*/
    (xmm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => {
        $crate::simd::swizzle!(@ xmm_reg, x, $src, $mode $dest, (xmmword) $indices)
    };
    (ymm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => {
        $crate::simd::swizzle!(@ ymm_reg, y, $src, $mode $dest, (ymmword) $indices)
    };
    (zmm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => {
        $crate::simd::swizzle!(@ zmm_reg, z, $src, $mode $dest, (zmmword) $indices)
    };
    ($reg:ident, $src:expr, $dest:expr, mem $indices:expr) => {
        core::arch::asm!(
            "vpshufb {}, {}, [{}]",
            in($reg) $src,
            lateout($reg) $dest,
            in(reg) $indices,
            options(readonly, preserves_flags, nostack)
        );
    };
    ($reg:ident, $src:expr, $dest:expr, data($indices_reg:ident) $indices:expr) => {
        core::arch::asm!(
            "vpshufb {}, {}, {}",
            in($reg) $src,
            lateout($reg) $dest,
            in($indices_reg) $indices,
            options(pure, nomem, preserves_flags, nostack)
        );
    };
    //(@ $reg:ident, $src:expr, $dest:expr, ($indices_reg:ident) [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => {
    (@ $reg:ident, $token:ident, $src:expr, $mode:ident $dest:expr, ($indices_reg:ident) $indices:ident) => {
        core::arch::asm!(
            concat!(
                "vpshufb {:", stringify!($token), "}, {:", stringify!($token), "}, ",
                stringify!($indices_reg), " ptr [rip + .", stringify!($indices), "]"
            ),
            $mode($reg) $dest,
            in($reg) $src,
            options(pure, nomem, preserves_flags, nostack)
        );
        // core::arch::asm!("2:"
        //     $( , concat!("\n .byte ", stringify!($index)) )+
        //     $( $( , $crate::util::subst!([$padding], ["\n .zero 1"]) )+ )?
        //     , "\n3:\n", concat!(" vpshufb {}, {}, ", stringify!($indices_reg), " ptr [rip + 2b]"),
        //     in($reg) $src, lateout($reg) $dest)
    };
    // ($src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => {
    //     $crate::simd::swizzle!(@ $src, $dest, [$( stringify!($index) ),+] $( , [$( "\n ", subst!($padding, ""), "zero 1" )+] )?)
    // };
    // (@ $src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:literal )+] )?) => {
    //     core::arch::asm!(r#"
    // .indices:"#,
    //     $( "\n .byte ", $index ),+
    //     $( $( $padding ),+ )?
    //     r#"
    // lsb:
    //     vpshufb {}, {}, xmmword ptr [rip + .indices]
    // "#, in(xmm_reg) $src, lateout(xmm_reg) $dest)
    // };
}

pub use __swizzle as swizzle;
pub use __swizzle_indices as swizzle_indices;

#[inline(always)]
pub fn load_u64_m128(v: u64) -> arch::__m128i {
    unsafe {
        let out: _;
        core::arch::asm!(
            "vmovq {}, {}",
            lateout(xmm_reg) out,
            in(reg) v,
            options(pure, nomem, preserves_flags, nostack)
        );
        out
    }
}

/// Merges the low halves of `a` and `b` into a single register, with `a`'s low
/// half in the lower 64 bits and `b`'s low half in the upper 64 bits.
#[inline(always)]
pub fn merge_lo_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
    unsafe {
        // xmm0 = xmm1[0],xmm0[0]
        let out: _;
        core::arch::asm!(
            "vpunpcklqdq {}, {}, {}",
            lateout(xmm_reg) out,
            in(xmm_reg) a,
            in(xmm_reg) b,
            options(pure, nomem, preserves_flags, nostack)
        );
        out
    }
}

/// The args are in little-endian order: the first arg becomes the low 128 bits
/// of the result, the second the high 128 bits.
#[inline(always)]
pub fn merge_m128_m256(a: arch::__m128i, b: arch::__m128i) -> arch::__m256i {
    unsafe {
        let out: _;
        core::arch::asm!(
            "vinserti128 {}, {:y}, {}, 0x1",
            lateout(ymm_reg) out,
            in(ymm_reg) a,
            in(xmm_reg) b,
            options(pure, nomem, preserves_flags, nostack)
        );
        out
    }
}
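// Illustrative sketch of the intended lane layout (assumes an AVX2 test
// build): `vmovq` zero-extends the u64 into lane 0, `vpunpcklqdq` interleaves
// the two low quadwords, and `vinserti128` fills the upper 128 bits.
#[cfg(all(test, target_feature = "avx2"))]
mod merge_tests {
    use super::*;

    #[test]
    fn merge_layout() {
        let a = load_u64_m128(1);
        let b = load_u64_m128(2);
        // low half of `a`, then low half of `b`
        let ab: [u64; 2] = unsafe { core::mem::transmute(merge_lo_hi_m128(a, b)) };
        assert_eq!(ab, [1, 2]);
        // `a` in the low 128 bits, `b` in the high 128 bits
        let wide: [u64; 4] = unsafe { core::mem::transmute(merge_m128_m256(a, b)) };
        assert_eq!(wide, [1, 0, 2, 0]);
    }
}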
macro_rules! extract_lohi_bytes {
    (($mask:expr, $op12:ident, $op3:ident), $in:ident) => {{
        const MASK: arch::__m128i = unsafe { core::mem::transmute($mask) };
        unsafe {
            let out: _;
            core::arch::asm!(
                //concat!("vmovdqa {mask}, xmmword ptr [rip + .", stringify!($mask), "]"),
                "vextracti128 {inter}, {input:y}, 1",
                concat!(stringify!($op12), " {inter}, {inter}, {mask}"),
                concat!(stringify!($op12), " {output:x}, {input:x}, {mask}"),
                concat!(stringify!($op3), " {output:x}, {output:x}, {inter}"),
                mask = in(xmm_reg) MASK,
                input = in(ymm_reg) $in,
                output = lateout(xmm_reg) out,
                inter = out(xmm_reg) _,
                options(pure, nomem, preserves_flags, nostack)
            );
            out
        }
    }};
}

/// Extracts the low byte of each u16 lane of `v` into a single `__m128i`.
#[inline(always)]
pub fn extract_lo_bytes(v: arch::__m256i) -> arch::__m128i {
    extract_lohi_bytes!(([0xffu16; 8], vpand, vpackuswb), v)
}

/// Extracts the high byte of each u16 lane of `v` into a single `__m128i`.
#[inline(always)]
pub fn extract_hi_bytes(v: arch::__m256i) -> arch::__m128i {
    extract_lohi_bytes!(
        (
            [0x1u8, 0x3, 0x5, 0x7, 0x9, 0xb, 0xd, 0xf, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0],
            vpshufb,
            vpunpcklqdq
        ),
        v
    )
}

pub trait SimdTestAnd {
    /// Returns true if the result of the bitwise AND of `self` and `mask` is not all zero.
    fn test_and_non_zero(self, mask: Self) -> bool;
}

pub trait SimdBitwise {
    /// Returns the bitwise OR of `self` and `rhs`.
    fn or(self, rhs: Self) -> Self;

    /// Returns the bitwise AND of `self` and `rhs`.
    fn and(self, rhs: Self) -> Self;
}

#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m128i {
    #[inline(always)]
    fn test_and_non_zero(self, mask: Self) -> bool {
        unsafe {
            let out: u8;
            // VPTEST sets ZF iff (self AND mask) is all zeros. MOV does not touch
            // flags, so after an untaken JZ the JNZ is always taken.
            core::arch::asm!(
                "vptest {a}, {b}",
                "jz 2f",
                "mov {out}, 1",
                "jnz 3f",
                "2:",
                "mov {out}, 0",
                "3:",
                a = in(xmm_reg) self,
                b = in(xmm_reg) mask,
                out = out(reg_byte) out,
                options(pure, nomem, nostack)
            );
            core::mem::transmute(out)
        }
    }
}

#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m256i {
    #[inline(always)]
    fn test_and_non_zero(self, mask: Self) -> bool {
        unsafe {
            let out: u8;
            core::arch::asm!(
                "vptest {a}, {b}",
                "jz 2f",
                "mov {out}, 1",
                "jnz 3f",
                "2:",
                "mov {out}, 0",
                "3:",
                a = in(ymm_reg) self,
                b = in(ymm_reg) mask,
                out = out(reg_byte) out,
                options(pure, nomem, nostack)
            );
            core::mem::transmute(out)
        }
    }
}

const USE_BITWISE_INTRINSICS: bool = true;

#[cfg(target_feature = "avx")]
impl SimdBitwise for arch::__m128i {
    #[inline(always)]
    fn or(self, rhs: Self) -> Self {
        unsafe {
            if USE_BITWISE_INTRINSICS {
                arch::_mm_or_si128(self, rhs)
            } else {
                let out: _;
                core::arch::asm!(
                    "vpor {out}, {a}, {b}",
                    a = in(xmm_reg) self,
                    b = in(xmm_reg) rhs,
                    out = out(xmm_reg) out,
                    options(pure, nomem, preserves_flags, nostack)
                );
                out
            }
        }
    }

    #[inline(always)]
    fn and(self, rhs: Self) -> Self {
        unsafe {
            if USE_BITWISE_INTRINSICS {
                arch::_mm_and_si128(self, rhs)
            } else {
                let out: _;
                core::arch::asm!(
                    "vpand {out}, {a}, {b}",
                    a = in(xmm_reg) self,
                    b = in(xmm_reg) rhs,
                    out = out(xmm_reg) out,
                    options(pure, nomem, preserves_flags, nostack)
                );
                out
            }
        }
    }
}

#[cfg(target_feature = "avx2")]
impl SimdBitwise for arch::__m256i {
    #[inline(always)]
    fn or(self, rhs: Self) -> Self {
        unsafe {
            if USE_BITWISE_INTRINSICS {
                arch::_mm256_or_si256(self, rhs)
            } else {
                let out: _;
                core::arch::asm!(
                    "vpor {out}, {a}, {b}",
                    a = in(ymm_reg) self,
                    b = in(ymm_reg) rhs,
                    out = out(ymm_reg) out,
                    options(pure, nomem, preserves_flags, nostack)
                );
                out
            }
        }
    }

    #[inline(always)]
    fn and(self, rhs: Self) -> Self {
        unsafe {
            if USE_BITWISE_INTRINSICS {
                arch::_mm256_and_si256(self, rhs)
            } else {
                let out: _;
                core::arch::asm!(
                    "vpand {out}, {a}, {b}",
                    a = in(ymm_reg) self,
                    b = in(ymm_reg) rhs,
                    out = out(ymm_reg) out,
                    options(pure, nomem, preserves_flags, nostack)
                );
                out
            }
        }
    }
}
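// Hypothetical sanity check (assumes an AVX-enabled test build):
// `test_and_non_zero` is the VPTEST analogue of `(a & b) != 0`, and the
// bitwise helpers should agree with the corresponding portable results.
#[cfg(all(test, target_feature = "avx"))]
mod bitwise_tests {
    use super::*;

    #[test]
    fn test_and_and_bitwise() {
        unsafe {
            let ones = arch::_mm_set1_epi8(-1);
            let zero = arch::_mm_setzero_si128();
            assert!(ones.test_and_non_zero(ones));
            assert!(!ones.test_and_non_zero(zero));
            let or: [u8; 16] = core::mem::transmute(zero.or(ones));
            assert_eq!(or, [0xff; 16]);
            let and: [u8; 16] = core::mem::transmute(zero.and(ones));
            assert_eq!(and, [0x00; 16]);
        }
    }
}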