use std::simd::{LaneCount, Simd, SupportedLaneCount};
use crate::util::cast;

/// Byte width of a 128-bit vector (16).
const W_128: usize = 128 / 8;
/// Byte width of a 256-bit vector (32).
const W_256: usize = 256 / 8;
/// Byte width of a 512-bit vector (64).
const W_512: usize = 512 / 8;

// Alias the current target's `std::arch` module as `arch`, so helpers in this file can
// name intrinsic types (e.g. `arch::__m128i`) without per-architecture paths.
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64 as arch;
#[cfg(target_arch = "arm")]
use std::arch::arm as arch;
#[cfg(target_arch = "wasm32")]
use std::arch::wasm32 as arch;
#[cfg(target_arch = "wasm64")]
use std::arch::wasm64 as arch;
#[cfg(target_arch = "x86")]
use std::arch::x86 as arch;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64 as arch;

/// Defines functions whose body is chosen at compile time from the value bound to `$LANES`.
///
/// Each `fn` entry expands to an `#[inline(always)]` function whose body is a `match` on
/// `$LANES`. Every arm `WIDTH if META => EXPR` becomes `#[cfg(META)] WIDTH => EXPR`. The
/// `if` clause is therefore a compile-time `cfg` predicate, not a runtime guard. Arms whose
/// predicate is false are removed entirely, so the arm list must still end in a
/// catch-all (`_ => ...`).
///
/// NOTE(review): the matcher requires `<$LANES ...>` after the function name. However, the
/// expansion below emits `fn $name(` with no generic parameter list while still matching
/// on `$LANES`. A `<const $LANES: usize ...>` list appears to be missing here; confirm
/// against the real file.
macro_rules! specialized {
    ($(
        $vis:vis fn $name:ident<$LANES:ident$(, $( $generics:tt )+)?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
            $( $width:pat_param $( if $cfg:meta )? => $impl:expr ),+ $(,)?
        }
    )+) => {$(
        #[inline(always)]
        $vis fn $name($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
            // abusing const generics to specialize without the unsoundness of real specialization!
            match $LANES {
                $( $( #[cfg( $cfg )] )? $width => $impl ),+
            }
        }
    )+};
}

/// Broadcasts the byte `$n` to every byte lane of a `$vec` register with `vpbroadcastb`.
///
/// Expands to an `asm!` call, so it must be used inside an `unsafe` block. In this file it
/// is only invoked from arms gated on `target_feature = "avx2"`.
///
/// NOTE(review): `$n` is a `u8` at the call sites here, and it is converted to a 128-bit
/// `__m128i` via `util::cast`. Whether that is sound depends on how `cast` handles the
/// size mismatch; confirm.
macro_rules! set1 {
    ($arch:ident, $vec:ident, $reg:ident, $n:ident) => {{
        let out: std::arch::$arch::$vec;
        std::arch::asm!("vpbroadcastb {}, {}", lateout($reg) out, in(xmm_reg) cast::<_, std::arch::$arch::__m128i>($n));
        out
    }};
}

specialized! {
    // Splats `n` into every byte lane. The implementation is chosen per vector width and
    // per enabled target features, falling back to the portable `Simd::splat`.
    // NOTE(review): as shown, this signature has no `<LANES>` list, and `Simd` /
    // `LaneCount` have no generic arguments, so it cannot match `specialized!`'s matcher.
    // It is presumably `splat_n<LANES>(n: u8) -> Simd<u8, LANES>` with
    // `LaneCount<LANES>: SupportedLaneCount`; confirm.
    // TODO: special case https://www.felixcloutier.com/x86/vpbroadcastb:vpbroadcastw:vpbroadcastd:vpbroadcastq
    pub fn splat_n(n: u8) -> Simd
    where [
        LaneCount: SupportedLaneCount,
    ] {
        // AVX2: a single `vpbroadcastb` through the `set1!` inline-asm helper.
        W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, __m128i, xmm_reg, n)) },
        W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, __m128i, xmm_reg, n)) },
        W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, __m256i, ymm_reg, n)) },
        W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, __m256i, ymm_reg, n)) },
        // these are *terrible*. They compile to a bunch of MOVs and SETs
        W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) },
        W_128 if all(target_arch = "x86", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm_set1_epi8(n as i8)) },
        W_256 if all(target_arch = "x86_64", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm256_set1_epi8(n as i8)) },
        W_256 if all(target_arch = "x86", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm256_set1_epi8(n as i8)) },
        // I can't really test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
        W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi8(n as i8)) },
        W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi8(n as i8)) },
        // Any other width, or a target without the features above: portable splat.
        _ => Simd::splat(n),
    }

    // Returns an all-zero vector, using the `setzero` intrinsics where the features allow.
    // NOTE(review): same apparent missing-generics issue as `splat_n` above; confirm.
    pub fn splat_0() -> Simd
    where [
        LaneCount: SupportedLaneCount,
    ] {
        // these are fine, they are supposed to XOR themselves to zero out.
        W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
        W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
        W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
        W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
        W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
        W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
        // Other widths and targets reuse the generic splat with a zero byte.
        _ => splat_n(0),
    }
}

/// Defines the indices used by [`swizzle`].
#[macro_export] macro_rules! __swizzle_indices { ($name:ident = [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => { std::arch::global_asm!(concat!(".", stringify!($name), ":") $( , concat!("\n .byte ", stringify!($index)) )+ $( $( , $crate::util::subst!([$padding], ["\n .zero 1"]) )+ )?); }; } #[macro_export] macro_rules! __swizzle { /*(xmm_reg, $src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => { $crate::simd::swizzle!(@ xmm_reg, $src, $dest, (xmmword) [$( $index ),+] $( , [$( $padding )+] )?) }; (ymm_reg, $src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => { $crate::simd::swizzle!(@ ymm_reg, $src, $dest, (ymmword) [$( $index ),+] $( , [$( $padding )+] )?) }; (zmm_reg, $src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => { $crate::simd::swizzle!(@ zmm_reg, $src, $dest, (zmmword) [$( $index ),+] $( , [$( $padding )+] )?) };*/ (xmm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => { $crate::simd::swizzle!(@ xmm_reg, x, $src, $mode $dest, (xmmword) $indices) }; (ymm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => { $crate::simd::swizzle!(@ ymm_reg, y, $src, $mode $dest, (ymmword) $indices) }; (zmm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => { $crate::simd::swizzle!(@ zmm_reg z, $src, $mode $dest, (zmmword) $indices) }; ($reg:ident, $src:expr, $dest:expr, mem $indices:expr) => { std::arch::asm!("vpshufb {}, {}, [{}]", in($reg) $src, lateout($reg) $dest, in(reg) $indices); }; ($reg:ident, $src:expr, $dest:expr, data($indices_reg:ident) $indices:expr) => { std::arch::asm!("vpshufb {}, {}, {}", in($reg) $src, lateout($reg) $dest, in($indices_reg) $indices); }; //(@ $reg:ident, $src:expr, $dest:expr, ($indices_reg:ident) [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) 
=> { (@ $reg:ident, $token:ident, $src:expr, $mode:ident $dest:expr, ($indices_reg:ident) $indices:ident) => { std::arch::asm!(concat!("vpshufb {:", stringify!($token), "}, {:", stringify!($token), "}, ", stringify!($indices_reg), " ptr [rip + .", stringify!($indices), "]"), $mode($reg) $dest, in($reg) $src); // std::arch::asm!("2:" // $( , concat!("\n .byte ", stringify!($index)) )+ // $( $( , $crate::util::subst!([$padding], ["\n .zero 1"]) )+ )? // , "\n3:\n", concat!(" vpshufb {}, {}, ", stringify!($indices_reg), " ptr [rip + 2b]"), in($reg) $src, lateout($reg) $dest) }; // ($src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => { // $crate::simd::swizzle!(@ $src, $dest, [$( stringify!($index) ),+] $( , [$( "\n ", subst!($padding, ""), "zero 1" )+] )?) // }; // (@ $src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:literal )+] )?) => { // std::arch::asm!(r#" // .indices:"#, // $( "\n .byte ", $index ),+ // $( $( $padding ),+ )? // r#" // lsb: // vpshufb {}, {}, xmmword ptr [rip + .indices] // "#, in(xmm_reg) $src, lateout(xmm_reg) $dest) // }; } pub use __swizzle as swizzle; pub use __swizzle_indices as swizzle_indices; #[inline(always)] pub fn load_u64_m128(v: u64) -> arch::__m128i { unsafe { let out: _; std::arch::asm!("vmovq {}, {}", lateout(xmm_reg) out, in(reg) v); out } } #[inline(always)] pub fn merge_low_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i { unsafe { // xmm0 = xmm1[0],xmm0[0] let out: _; std::arch::asm!("vpunpcklqdq {}, {}, {}", lateout(xmm_reg) out, in(xmm_reg) a, in(xmm_reg) b); out } } /// The args are in little endian order (first arg is lowest order) pub fn merge_m128_m256(a: arch::__m128i, b: arch::__m128i) -> arch::__m256i { unsafe { let out: _; std::arch::asm!("vinserti128 {}, {:y}, {}, 0x1", lateout(ymm_reg) out, in(ymm_reg) a, in(xmm_reg) b); out } } macro_rules! 
extract_lohi_bytes { (($mask:expr, $op12:ident, $op3:ident), $in:ident) => {{ const MASK: arch::__m128i = unsafe { std::mem::transmute($mask) }; unsafe { let out: _; std::arch::asm!( //concat!("vmovdqa {mask}, xmmword ptr [rip + .", stringify!($mask), "]"), "vextracti128 {inter}, {input:y}, 1", concat!(stringify!($op12), " {inter}, {inter}, {mask}"), concat!(stringify!($op12), " {output:x}, {input:x}, {mask}"), concat!(stringify!($op3), " {output:x}, {output:x}, {inter}"), mask = in(xmm_reg) MASK, input = in(ymm_reg) $in, output = lateout(xmm_reg) out, inter = out(xmm_reg) _); out } }}; } #[inline(always)] pub fn extract_lo_bytes(v: arch::__m256i) -> arch::__m128i { extract_lohi_bytes!(([0xffu16; 8], vpand, vpackuswb), v) } #[inline(always)] pub fn extract_hi_bytes(v: arch::__m256i) -> arch::__m128i { extract_lohi_bytes!(([0x1u8, 0x3, 0x5, 0x7, 0x9, 0xb, 0xd, 0xf, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0], vpshufb, vpunpcklqdq), v) }