// fast-hex/src/simd.rs
// (90 lines, 3.6 KiB, Rust; last changed 2022-10-30 15:35:17 -04:00)
use std::simd::{LaneCount, Simd, SupportedLaneCount};
use crate::util::cast;
/// Source of an all-zero byte vector with `LANES` lanes.
///
/// The single associated function takes no input and returns a
/// `Simd<u8, LANES>`. It is only callable when `LaneCount<LANES>` implements
/// `SupportedLaneCount`. `SimdOps` implements this trait for every `LANES`
/// in this file.
pub trait SimdSplatZero<const LANES: usize> {
/// Returns the zero vector. The `SimdOps` implementation in this file uses
/// `setzero` intrinsics where available and otherwise falls back to
/// `splat_n(0)`.
fn splat_zero() -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount;
}
/// Source of a byte vector with `LANES` lanes built from a single byte `n`.
///
/// The single associated function is only callable when `LaneCount<LANES>`
/// implements `SupportedLaneCount`. `SimdOps` implements this trait for every
/// `LANES` in this file.
pub trait SimdSplatN<const LANES: usize> {
/// Returns a vector derived from `n`. The `SimdOps` implementation in this
/// file uses `set1_epi8` intrinsics where available and otherwise
/// `Simd::splat(n)`.
fn splat_n(n: u8) -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount;
}
/// Unit struct used purely as the implementor of the splat traits; it carries
/// no data.
pub struct SimdOps;
// Vector widths in bytes (bit width / 8). Because the vectors in this file
// have `u8` lanes, each value is also the lane count matched against `LANES`.
const W_128: usize = 128 / 8;
const W_256: usize = 256 / 8;
const W_512: usize = 512 / 8;
// Generates `impl<const LANES: usize> Trait<LANES> for SimdOps` from a compact
// description of each method.
//
// Invocation shape:
//
//   specialized!(LANES, Trait {
//       fn name(arg: Ty, ...) -> Ret where [ bounds ] {
//           WIDTH if cfg_predicate => expr,
//           _ => fallback_expr,
//       }
//   });
//
// * The first argument is the identifier used for the const generic parameter.
// * `where [ ... ]` is optional; the bracketed tokens are emitted after `where`
//   on the generated method.
// * `if cfg_predicate` is optional and is NOT a runtime match guard: it is
//   turned into `#[cfg(cfg_predicate)]` on the generated match arm, so the arm
//   only exists when the predicate holds at compile time.
// * Every generated method is `#[inline(always)]` and its body is a single
//   `match` on the const generic.
macro_rules! specialized {
// Matcher: const-generic identifier, trait name, then zero or more method
// descriptions, each with one or more comma-separated arms (trailing comma
// allowed).
($LANES:ident, $trait:ident {
$(
fn $name:ident($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
$width:pat_param $( if $cfg:meta )? => $impl:expr
),+
$(,)?
}
)*
}) => {
// Transcriber: one impl block on `SimdOps`, generic over the lane count.
impl<const $LANES: usize> $trait<$LANES> for SimdOps {
$(
#[inline(always)]
fn $name($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
// abusing const generics to specialize without the unsoundness of real specialization!
match $LANES {
$(
// The optional `if <meta>` from the invocation becomes a conditional-compilation
// attribute on this arm.
$( #[cfg( $cfg )] )?
$width => $impl
),+
}
}
)*
}
};
}
// `SimdOps::splat_n`: builds a `Simd<u8, LANES>` from the byte `n`, choosing
// the implementation by `LANES`.
//
// When `LANES` is W_128 / W_256 / W_512 and the crate is compiled for x86 or
// x86_64 with sse2 / avx / avx512f respectively, the matching `set1_epi8`
// intrinsic is called and its result is converted with `cast`. Any other lane
// count, or a build without those cfgs, uses `Simd::splat(n)`.
specialized!(LANES, SimdSplatN {
fn splat_n(n: u8) -> Simd<u8, LANES> where [LaneCount<LANES>: SupportedLaneCount] {
// SAFETY (applies to every `unsafe` arm below): an arm is compiled only when
// its `#[cfg]` reports the required target feature as enabled, and it is
// selected only when `LANES` equals the intrinsic's vector width in bytes.
// NOTE(review): `cast` is defined in `crate::util` and is not visible here;
// presumably it is a same-size reinterpretation of the intrinsic's vector
// type as `Simd<u8, LANES>` — confirm against its definition.
// `n as i8` keeps the bit pattern of `n`; the intrinsics take a signed byte.
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_set1_epi8(n as i8)) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi8(n as i8)) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi8(n as i8)) },
// Portable fallback for all remaining lane counts and targets.
_ => Simd::splat(n),
}
});
// `SimdOps::splat_zero`: builds an all-zero `Simd<u8, LANES>`, choosing the
// implementation by `LANES`.
//
// When `LANES` is W_128 / W_256 / W_512 and the crate is compiled for x86 or
// x86_64 with sse2 / avx / avx512f respectively, the matching `setzero`
// intrinsic is called and its result is converted with `cast`. Any other lane
// count, or a build without those cfgs, delegates to `splat_n(0)`.
specialized!(LANES, SimdSplatZero {
fn splat_zero() -> Simd<u8, LANES> where [LaneCount<LANES>: SupportedLaneCount] {
// SAFETY (applies to every `unsafe` arm below): an arm is compiled only when
// its `#[cfg]` reports the required target feature as enabled, and it is
// selected only when `LANES` equals the intrinsic's vector width in bytes.
// NOTE(review): `cast` is defined in `crate::util` and is not visible here;
// presumably it is a same-size reinterpretation of the intrinsic's vector
// type as `Simd<u8, LANES>` — confirm against its definition.
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
// Fallback: reuse the `SimdSplatN` implementation with a zero byte.
_ => <Self as SimdSplatN<LANES>>::splat_n(0),
}
});
#[inline(always)]
pub fn splat_0<const LANES: usize>() -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount,
{
<SimdOps as SimdSplatZero<LANES>>::splat_zero()
}
#[inline(always)]
pub fn splat_n<const LANES: usize>(n: u8) -> Simd<u8, LANES>
where
LaneCount<LANES>: SupportedLaneCount,
{
<SimdOps as SimdSplatN<LANES>>::splat_n(n)
}