use std::simd::{LaneCount, Simd, SupportedLaneCount}; use crate::util::cast; pub trait SimdSplatZero { fn splat_zero() -> Simd where LaneCount: SupportedLaneCount; } pub trait SimdSplatN { fn splat_n(n: u8) -> Simd where LaneCount: SupportedLaneCount; } pub struct SimdOps; const W_128: usize = 128 / 8; const W_256: usize = 256 / 8; const W_512: usize = 512 / 8; macro_rules! specialized { ($LANES:ident, $trait:ident { $( fn $name:ident($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? { $( $width:pat_param $( if $cfg:meta )? => $impl:expr ),+ $(,)? } )* }) => { impl $trait<$LANES> for SimdOps { $( #[inline(always)] fn $name($( $argn: $argt ),*) -> $rt $( where $( $where )* )? { // abusing const generics to specialize without the unsoundness of real specialization! match $LANES { $( $( #[cfg( $cfg )] )? $width => $impl ),+ } } )* } }; } specialized!(LANES, SimdSplatN { fn splat_n(n: u8) -> Simd where [LaneCount: SupportedLaneCount] { W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) }, W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_set1_epi8(n as i8)) }, W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_set1_epi8(n as i8)) }, W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_set1_epi8(n as i8)) }, W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi8(n as i8)) }, W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi8(n as i8)) }, _ => Simd::splat(n), } }); specialized!(LANES, SimdSplatZero { fn splat_zero() -> Simd where [LaneCount: SupportedLaneCount] { W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) }, W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) }, W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) }, W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) }, W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) }, W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) }, _ => >::splat_n(0), } }); #[inline(always)] pub fn splat_0() -> Simd where LaneCount: SupportedLaneCount, { >::splat_zero() } #[inline(always)] pub fn splat_n(n: u8) -> Simd where LaneCount: SupportedLaneCount, { >::splat_n(n) }