90 lines
3.6 KiB
Rust
90 lines
3.6 KiB
Rust
|
use std::simd::{LaneCount, Simd, SupportedLaneCount};
|
||
|
|
||
|
use crate::util::cast;
|
||
|
|
||
|
pub trait SimdSplatZero<const LANES: usize> {
|
||
|
fn splat_zero() -> Simd<u8, LANES>
|
||
|
where
|
||
|
LaneCount<LANES>: SupportedLaneCount;
|
||
|
}
|
||
|
|
||
|
pub trait SimdSplatN<const LANES: usize> {
|
||
|
fn splat_n(n: u8) -> Simd<u8, LANES>
|
||
|
where
|
||
|
LaneCount<LANES>: SupportedLaneCount;
|
||
|
}
|
||
|
|
||
|
/// Field-less marker type that carries the `SimdSplatZero` / `SimdSplatN`
/// implementations generated in this file; it is never instantiated here, only
/// named in `<SimdOps as Trait<LANES>>::method()` calls.
pub struct SimdOps;
|
||
|
|
||
|
/// Number of `u8` lanes in a 128-bit vector (128 bits / 8 bits per lane = 16).
const W_128: usize = 128 / 8;
/// Number of `u8` lanes in a 256-bit vector (256 bits / 8 bits per lane = 32).
const W_256: usize = 256 / 8;
/// Number of `u8` lanes in a 512-bit vector (512 bits / 8 bits per lane = 64).
const W_512: usize = 512 / 8;
|
||
|
|
||
|
/// Generates `impl<const LANES: usize> Trait<LANES> for SimdOps` where each
/// method body is a `match` on the lane-count const generic.
///
/// Invocation shape:
///
/// ```text
/// specialized!(LANES, Trait {
///     fn name(arg: Ty, ...) -> Ret where [ bounds ] {
///         PATTERN if CFG_PREDICATE => expr,
///         PATTERN => expr,
///         ...
///     }
/// });
/// ```
///
/// * `LANES` is the identifier used both for the impl's const generic and as
///   the `match` scrutinee.
/// * The optional `where [ ... ]` tokens are pasted verbatim after `where` on
///   the generated method.
/// * `if CFG_PREDICATE` is *not* a match guard: it is rewritten to
///   `#[cfg(CFG_PREDICATE)]` on the arm, so the arm only exists when the
///   predicate holds at compile time.
/// * The arm list must be exhaustive once `cfg`-disabled arms are removed, so
///   callers normally finish with a `_ => ...` fallback.
///
/// Every generated method is marked `#[inline(always)]`.
macro_rules! specialized {
    ($LANES:ident, $trait:ident {
        $(
            fn $name:ident($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
                $(
                    $width:pat_param $( if $cfg:meta )? => $impl:expr
                ),+
                $(,)?
            }
        )*
    }) => {
        impl<const $LANES: usize> $trait<$LANES> for SimdOps {
            $(
                #[inline(always)]
                fn $name($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
                    // abusing const generics to specialize without the unsoundness of real specialization!
                    // The scrutinee is the impl's const generic, so every
                    // monomorphized instance matches on a compile-time constant.
                    match $LANES {
                        $(
                            // `if <predicate>` from the invocation becomes a `cfg` on the arm.
                            $( #[cfg( $cfg )] )?
                            $width => $impl
                        ),+
                    }
                }
            )*
        }
    };
}
|
||
|
|
||
|
// `SimdSplatN` for `SimdOps`: for 16-, 32- and 64-lane vectors on x86 / x86_64
// the matching `set1_epi8` intrinsic is used when its `cfg` predicate holds;
// every other lane count (and every other target) falls back to `Simd::splat`.
specialized!(LANES, SimdSplatN {
    fn splat_n(n: u8) -> Simd<u8, LANES> where [LaneCount<LANES>: SupportedLaneCount] {
        // SAFETY (applies to every `unsafe` arm below): an arm is compiled only
        // when its `cfg` predicate reports the named target feature as enabled at
        // compile time, and it is selected only when `LANES` equals the byte width
        // of the vector type the intrinsic returns.
        // NOTE(review): `cast` comes from `crate::util` and is not visible here;
        // presumably it reinterprets the intrinsic's vector as `Simd<u8, LANES>`
        // of identical size. Confirm, and confirm that each intrinsic needs no
        // feature beyond the one named in its `cfg`.
        //
        // `n as i8` keeps the bit pattern; the intrinsics take a signed byte.
        W_128 if all(target_arch = "x86_64", target_feature = "sse2") =>
            unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) },
        W_128 if all(target_arch = "x86", target_feature = "sse2") =>
            unsafe { cast(core::arch::x86::_mm_set1_epi8(n as i8)) },
        W_256 if all(target_arch = "x86_64", target_feature = "avx") =>
            unsafe { cast(core::arch::x86_64::_mm256_set1_epi8(n as i8)) },
        W_256 if all(target_arch = "x86", target_feature = "avx") =>
            unsafe { cast(core::arch::x86::_mm256_set1_epi8(n as i8)) },
        W_512 if all(target_arch = "x86_64", target_feature = "avx512f") =>
            unsafe { cast(core::arch::x86_64::_mm512_set1_epi8(n as i8)) },
        W_512 if all(target_arch = "x86", target_feature = "avx512f") =>
            unsafe { cast(core::arch::x86::_mm512_set1_epi8(n as i8)) },
        // Portable fallback for all remaining widths and targets.
        _ => Simd::splat(n),
    }
});
|
||
|
|
||
|
// `SimdSplatZero` for `SimdOps`: for 16-, 32- and 64-lane vectors on x86 /
// x86_64 the matching `setzero` intrinsic is used when its `cfg` predicate
// holds; every other lane count (and every other target) delegates to
// `SimdSplatN::splat_n(0)`.
specialized!(LANES, SimdSplatZero {
    fn splat_zero() -> Simd<u8, LANES> where [LaneCount<LANES>: SupportedLaneCount] {
        // SAFETY (applies to every `unsafe` arm below): an arm is compiled only
        // when its `cfg` predicate reports the named target feature as enabled at
        // compile time, and it is selected only when `LANES` equals the byte width
        // of the vector type the intrinsic returns.
        // NOTE(review): `cast` comes from `crate::util` and is not visible here;
        // presumably it reinterprets the intrinsic's vector as `Simd<u8, LANES>`
        // of identical size. Confirm, and confirm that each intrinsic needs no
        // feature beyond the one named in its `cfg`.
        W_128 if all(target_arch = "x86_64", target_feature = "sse2") =>
            unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
        W_128 if all(target_arch = "x86", target_feature = "sse2") =>
            unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
        W_256 if all(target_arch = "x86_64", target_feature = "avx") =>
            unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
        W_256 if all(target_arch = "x86", target_feature = "avx") =>
            unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
        W_512 if all(target_arch = "x86_64", target_feature = "avx512f") =>
            unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
        W_512 if all(target_arch = "x86", target_feature = "avx512f") =>
            unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
        // No dedicated zeroing path: reuse the byte-splat implementation with 0.
        _ => <Self as SimdSplatN<LANES>>::splat_n(0),
    }
});
|
||
|
|
||
|
#[inline(always)]
|
||
|
pub fn splat_0<const LANES: usize>() -> Simd<u8, LANES>
|
||
|
where
|
||
|
LaneCount<LANES>: SupportedLaneCount,
|
||
|
{
|
||
|
<SimdOps as SimdSplatZero<LANES>>::splat_zero()
|
||
|
}
|
||
|
|
||
|
#[inline(always)]
|
||
|
pub fn splat_n<const LANES: usize>(n: u8) -> Simd<u8, LANES>
|
||
|
where
|
||
|
LaneCount<LANES>: SupportedLaneCount,
|
||
|
{
|
||
|
<SimdOps as SimdSplatN<LANES>>::splat_n(n)
|
||
|
}
|