From ed40b974bea342c2355c3f90f5a62d12b6ac22f6 Mon Sep 17 00:00:00 2001 From: Michael Pfaff Date: Thu, 15 Dec 2022 07:57:18 -0500 Subject: [PATCH] Initial commit --- .gitignore | 2 + Cargo.toml | 12 + src/lib.rs | 16 + src/simd.rs | 836 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/util.rs | 255 ++++++++++++++++ 5 files changed, 1121 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 src/lib.rs create mode 100644 src/simd.rs create mode 100644 src/util.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..8d2f528 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "brisk" +version = "0.1.0" +edition = "2021" + +[features] +default = ["std"] +alloc = [] +std = ["alloc"] +test = [] + +[dependencies] diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..6c51469 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,16 @@ +#![feature(const_eval_select)] +#![feature(const_likely)] +#![feature(const_maybe_uninit_array_assume_init)] +#![feature(const_maybe_uninit_uninit_array)] +#![feature(const_trait_impl)] +#![feature(core_intrinsics)] +#![feature(maybe_uninit_array_assume_init)] +#![feature(maybe_uninit_uninit_array)] +#![feature(portable_simd)] +#![feature(ptr_metadata)] +#![feature(stdsimd)] + +pub mod simd; + +#[macro_use] +pub mod util; diff --git a/src/simd.rs b/src/simd.rs new file mode 100644 index 0000000..2965b6e --- /dev/null +++ b/src/simd.rs @@ -0,0 +1,836 @@ +use core::intrinsics::const_eval_select; +use core::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount}; + +use crate::util; + +use util::cast; + +const W_128: usize = 128 / 8; +const W_256: usize = 256 / 8; +const W_512: usize = 512 / 8; + +/// The value which cause `vpshufb` to write 0 instead of indexing. +const SHUF_0: u8 = 0b1000_0000; + +#[cfg(target_arch = "aarch64")] +pub use core::arch::aarch64 as arch; +#[cfg(target_arch = "arm")] +pub use core::arch::arm as arch; +#[cfg(target_arch = "wasm32")] +pub use core::arch::wasm32 as arch; +#[cfg(target_arch = "wasm64")] +pub use core::arch::wasm64 as arch; +#[cfg(target_arch = "x86")] +pub use core::arch::x86 as arch; +#[cfg(target_arch = "x86_64")] +pub use core::arch::x86_64 as arch; + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +pub use arch::{__m128i, __m256i, __m512i}; + +// use the maximum batch size that would be supported by AVX-512 +//pub(crate) const SIMD_WIDTH: usize = 512; +pub const SIMD_WIDTH: usize = 256; + +/// The batch size used for the "wide" decoded hex bytes (any bit in the upper half indicates an error). +pub const WIDE_BATCH_SIZE: usize = SIMD_WIDTH / 16; + +/// The batch size used for the hex digits. +pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2; + +pub const GATHER_BATCH_SIZE: usize = DIGIT_BATCH_SIZE / 4; + +#[macro_export] +macro_rules! __if_trace_simd { + ($( $tt:tt )*) => { + // disabled + //{ $( $tt )* } + }; +} + +pub use __if_trace_simd as if_trace_simd; + +pub trait IsSimd { + type Lane; + + const LANES: usize; +} + +impl IsSimd for Simd +where + LaneCount: SupportedLaneCount, + T: SimdElement, +{ + type Lane = T; + + const LANES: usize = LANES; +} + +macro_rules! specialized { + ($( + $vis:vis fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $args:tt )*) -> $rt:ty $(where [ $( $where:tt )* ])? { + $( + $width:pat_param $( if $cfg:meta )? 
=> $impl:expr + ),+ + $(,)? + } + )+) => {$( + #[allow(dead_code)] + #[inline(always)] + $vis fn $name($( $args )*) -> $rt $( where $( $where )* )? { + // abusing const generics to specialize without the unsoundness of real specialization! + match $LANES { + $( + $( #[cfg( $cfg )] )? + $width => $impl + ),+ + } + } + )+}; + ($LANES:ident => + $( + fn $name:ident$(<[$( $generics:tt )+]>)?($( $args:tt )*) -> $rt:ty $(where [ $( $where:tt )* ])? { + $( + $width:pat_param $( if $cfg:meta )? => $impl:expr + ),+ + $(,)? + } + )+ + ) => { + $( + #[inline(always)] + fn $name$(<$( $generics )+>)?($( $args )*) -> $rt $( where $( $where )* )? { + // abusing const generics to specialize without the unsoundness of real specialization! + match $LANES { + $( + $( #[cfg( $cfg )] )? + $width => $impl + ),+ + } + } + )+ + }; + ($trait:ident for $ty:ty; + $( + fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $args:tt )*) -> $rt:ty $(where [ $( $where:tt )* ])? { + $( + $width:pat_param $( if $cfg:meta )? => $impl:expr + ),+ + $(,)? + } + )+ + ) => { + impl $trait for $ty {$( + specialized! { LANES => + fn $name$(<[$( $generics:tt )+]>)?($( $args )*) -> $rt $(where [$( $where )*])? { + $( + $width $( if $cfg )? => $impl + ),+ + } + } + )+} + }; + } + +macro_rules! set1_short { + ($inst:ident, $vec:ident, $reg:ident, $n:ident: $n_ty:ty) => {{ + // WOW. this is 12% faster than broadcast on xmm. Seems not so much on ymm (more expensive + // array init?). + const O_LANES: usize = core::mem::size_of::<$vec>() / core::mem::size_of::<$n_ty>(); + util::cast::<_, $vec>(Simd::<$n_ty, O_LANES>::from_array([$n; O_LANES])) + //let out: $vec; + //core::arch::asm!(concat!(stringify!($inst), " {}, {}"), lateout($reg) out, in(xmm_reg) cast::<_, __m128i>($n), options(pure, nomem, preserves_flags, nostack)); + //out + }}; +} + +macro_rules! set1_long { + ($inst:ident, $vec:ident, $reg:ident, $n:ident: $n_ty:ty) => {{ + fn runtime(n: $n_ty) -> $vec { + unsafe { + //const O_LANES: usize = core::mem::size_of::<$vec>() / core::mem::size_of::<$n_ty>(); + //util::cast::<_, $vec>(Simd::<$n_ty, O_LANES>::from_array([$n; O_LANES])) + let out: $vec; + core::arch::asm!(concat!(stringify!($inst), " {}, {}"), lateout($reg) out, in(xmm_reg) cast::<_, __m128i>(n), options(pure, nomem, preserves_flags, nostack)); + out + } + } + const fn compiletime(n: $n_ty) -> $vec { + const O_LANES: usize = core::mem::size_of::<$vec>() / core::mem::size_of::<$n_ty>(); + unsafe { util::cast::<_, $vec>(Simd::<$n_ty, O_LANES>::from_array([n; O_LANES])) } + } + const_eval_select(($n,), compiletime, runtime) + }}; +} + +pub trait DoubleWidth { + type Output; +} + +macro_rules! impl_double_width { + ($($in:ty => $out:ty),+) => { + $( + impl DoubleWidth for $in { + type Output = $out; + } + )+ + } +} + +impl_double_width!(u8 => u16, u16 => u32, u32 => u64); + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +impl_double_width!(u64 => arch::__m128i, arch::__m128i => arch::__m256i, arch::__m256i => arch::__m512i); + +#[const_trait] +pub trait SimdSplat { + type Output where LaneCount: SupportedLaneCount; + + fn splat(self) -> Self::Output where LaneCount: SupportedLaneCount; + + fn splat_zero() -> Self::Output where LaneCount: SupportedLaneCount; +} + +pub trait SimdLoad: Sized { + fn load_128(self) -> arch::__m128i; + + #[inline(always)] + fn load_256(self) -> arch::__m256i { + unsafe { util::cast(self.load_128()) } + } + + #[inline(always)] + fn load_512(self) -> arch::__m512i { + unsafe { util::cast(self.load_128()) } + } +} + +macro_rules! 
impl_load { + (($in:ident, $($in_fmt:ident)?) $in_v:ident) => {{ + unsafe { + let out: _; + core::arch::asm!(concat!("vmovq {}, {:", $(stringify!($in_fmt), )? "}"), lateout(xmm_reg) out, in($in) $in_v, options(pure, nomem, preserves_flags, nostack)); + out + } + }} +} + +macro_rules! simd_load_fallback { + ($ty:ty, $vec_ty:ty, $self:ident) => {{ + let mut a = core::mem::MaybeUninit::uninit_array(); + a[0] = core::mem::MaybeUninit::new($self); + for i in 1..(core::mem::size_of::<$vec_ty>() / core::mem::size_of::<$ty>()) { + a[i] = core::mem::MaybeUninit::new(0); + } + let a = unsafe { core::mem::MaybeUninit::array_assume_init(a) }; + Simd::from_array(a).into() + }}; +} + +macro_rules! impl_ops { + ($( + $ty:ty { + /// The appropriate register type. + reg: $reg:ident, + reg_fmt: $($reg_fmt:ident)?, + broadcast: $broadcast:ident, + set1: [$set1_128:ident, $set1_256:ident, $set1_512:ident] + $(,)? + } + )+) => {$( + impl SimdLoad for $ty { + #[inline(always)] + fn load_128(self) -> arch::__m128i { + #[cfg(target_feature = "avx")] + { impl_load!(($reg, $($reg_fmt)?) self) } + #[cfg(not(target_feature = "avx"))] + simd_load_fallback!($ty, arch::__m128i, self) + } + } + + impl const SimdSplat for $ty { + type Output = Simd<$ty, LANES> where LaneCount: SupportedLaneCount; + + specialized! { + fn splat(self) -> Self::Output where [LaneCount: SupportedLaneCount] { + W_128 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx2") => unsafe { cast(set1_short!($broadcast, __m128i, xmm_reg, self: $ty)) }, + W_256 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx2") => unsafe { cast(set1_long!($broadcast, __m256i, ymm_reg, self: $ty)) }, + W_512 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx512bw") => unsafe { cast(set1_long!($broadcast, __m512i, zmm_reg, self: $ty)) }, + + // these are *terrible*. They compile to a bunch of MOVs and SETs + W_128 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(arch::$set1_128(self as i8)) }, + W_256 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(arch::$set1_256(self as i8)) }, + + // I can't actually test these, but they're documented as doing either a broadcast or the terrible approach mentioned above. 
+ W_512 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx512f") => unsafe { cast(arch::$set1_512(self as i8)) }, + _ => { + #[inline(always)] + const fn compiletime(_: $ty) -> Simd<$ty, LANES> + where + LaneCount: SupportedLaneCount, + { + panic!("unsupported compile-time splat"); + } + #[inline(always)] + fn runtime(v: $ty) -> Simd<$ty, LANES> + where + LaneCount: SupportedLaneCount, + { + Simd::splat(v) + } + unsafe { const_eval_select((self,), compiletime, runtime) } + }, + } + + fn splat_zero() -> Self::Output where [LaneCount: SupportedLaneCount] { + W_128 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "sse2") => { + #[inline(always)] + const fn compiletime() -> Simd<$ty, LANES> + where + LaneCount: SupportedLaneCount, + { + <$ty>::splat(0) + } + #[inline(always)] + fn runtime() -> Simd<$ty, LANES> + where + LaneCount: SupportedLaneCount, + { + unsafe { cast(arch::_mm_setzero_si128()) } + } + unsafe { const_eval_select((), compiletime, runtime) } + }, + W_256 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx") => { + #[inline(always)] + const fn compiletime() -> Simd<$ty, LANES> + where + LaneCount: SupportedLaneCount, + { + <$ty>::splat(0) + } + #[inline(always)] + fn runtime() -> Simd<$ty, LANES> + where + LaneCount: SupportedLaneCount, + { + unsafe { cast(arch::_mm256_setzero_si256()) } + } + unsafe { const_eval_select((), compiletime, runtime) } + }, + W_512 if all(any(target_arch = "x86", target_arch = "x86_64"), target_feature = "avx512f") => { + #[inline(always)] + const fn compiletime() -> Simd<$ty, LANES> + where + LaneCount: SupportedLaneCount, + { + <$ty>::splat(0) + } + #[inline(always)] + fn runtime() -> Simd<$ty, LANES> + where + LaneCount: SupportedLaneCount, + { + unsafe { cast(arch::_mm512_setzero_si512()) } + } + unsafe { const_eval_select((), compiletime, runtime) } + }, + _ => Self::splat(0), + } + } + } + )+}; +} + +impl_ops! { + u8 { + reg: reg_byte, + reg_fmt: , + broadcast: vpbroadcastb, + set1: [_mm_set1_epi8, _mm256_set1_epi8, _mm512_set1_epi8] + } + + u16 { + reg: reg, + reg_fmt: e, + broadcast: vpbroadcastw, + set1: [_mm_set1_epi16, _mm256_set1_epi16, _mm512_set1_epi16] + } + + u32 { + reg: reg, + reg_fmt: e, + broadcast: vpbroadcastd, + set1: [_mm_set1_epi32, _mm256_set1_epi32, _mm512_set1_epi32] + } + + u64 { + reg: reg, + reg_fmt: r, + broadcast: vpbroadcastq, + set1: [_mm_set1_epi64, _mm256_set1_epi64, _mm512_set1_epi64] + } +} + +/// Defines the indices used by [`swizzle`]. +#[macro_export] +macro_rules! __swizzle_indices { + ($name:ident = [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => { + core::arch::global_asm!(concat!(".", stringify!($name), ":") +$( , concat!("\n .byte ", stringify!($index)) )+ +$( $( , $crate::util::subst!([$padding], ["\n .zero 1"]) )+ )?); + }; +} + +#[macro_export] +macro_rules! __swizzle { + /*(xmm_reg, $src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => { + $crate::simd::swizzle!(@ xmm_reg, $src, $dest, (xmmword) [$( $index ),+] $( , [$( $padding )+] )?) + }; + (ymm_reg, $src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => { + $crate::simd::swizzle!(@ ymm_reg, $src, $dest, (ymmword) [$( $index ),+] $( , [$( $padding )+] )?) + }; + (zmm_reg, $src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => { + $crate::simd::swizzle!(@ zmm_reg, $src, $dest, (zmmword) [$( $index ),+] $( , [$( $padding )+] )?) 
+ };*/ + (xmm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => { + $crate::simd::swizzle!(@ xmm_reg, x, $src, $mode $dest, (xmmword) $indices) + }; + (ymm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => { + $crate::simd::swizzle!(@ ymm_reg, y, $src, $mode $dest, (ymmword) $indices) + }; + (zmm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => { + $crate::simd::swizzle!(@ zmm_reg z, $src, $mode $dest, (zmmword) $indices) + }; + ($reg:ident, $src:expr, $dest:expr, mem $indices:expr) => { + core::arch::asm!("vpshufb {}, {}, [{}]", in($reg) $src, lateout($reg) $dest, in(reg) $indices, options(readonly, preserves_flags, nostack)); + }; + ($reg:ident, $src:expr, $dest:expr, data($indices_reg:ident) $indices:expr) => { + core::arch::asm!("vpshufb {}, {}, {}", in($reg) $src, lateout($reg) $dest, in($indices_reg) $indices, options(pure, nomem, preserves_flags, nostack)); + }; + //(@ $reg:ident, $src:expr, $dest:expr, ($indices_reg:ident) [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => { + (@ $reg:ident, $token:ident, $src:expr, $mode:ident $dest:expr, ($indices_reg:ident) $indices:ident) => { + core::arch::asm!(concat!("vpshufb {:", stringify!($token), "}, {:", stringify!($token), "}, ", stringify!($indices_reg), " ptr [rip + .", stringify!($indices), "]"), $mode($reg) $dest, in($reg) $src, options(pure, nomem, preserves_flags, nostack)); +// core::arch::asm!("2:" +// $( , concat!("\n .byte ", stringify!($index)) )+ +// $( $( , $crate::util::subst!([$padding], ["\n .zero 1"]) )+ )? +// , "\n3:\n", concat!(" vpshufb {}, {}, ", stringify!($indices_reg), " ptr [rip + 2b]"), in($reg) $src, lateout($reg) $dest) + }; +// ($src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => { +// $crate::simd::swizzle!(@ $src, $dest, [$( stringify!($index) ),+] $( , [$( "\n ", subst!($padding, ""), "zero 1" )+] )?) +// }; +// (@ $src:expr, $dest:expr, [$( $index:literal ),+] $( , [$( $padding:literal )+] )?) => { +// core::arch::asm!(r#" +// .indices:"#, +// $( "\n .byte ", $index ),+ +// $( $( $padding ),+ )? +// r#" +// lsb: +// vpshufb {}, {}, xmmword ptr [rip + .indices] +// "#, in(xmm_reg) $src, lateout(xmm_reg) $dest) +// }; + } + +pub use __swizzle as swizzle; +pub use __swizzle_indices as swizzle_indices; + +/// Merges the low halves of `a` and `b` into a single register like `ab`. +#[inline(always)] +pub fn merge_lo_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i { + unsafe { + // xmm0 = xmm1[0],xmm0[0] + let out: _; + core::arch::asm!("vpunpcklqdq {}, {}, {}", lateout(xmm_reg) out, in(xmm_reg) a, in(xmm_reg) b, options(pure, nomem, preserves_flags, nostack)); + out + } +} + +/// The args are in little endian order (first arg is lowest order) +#[inline(always)] +pub fn merge_m128_m256(a: arch::__m128i, b: arch::__m128i) -> arch::__m256i { + unsafe { + let out: _; + core::arch::asm!("vinserti128 {}, {:y}, {}, 0x1", lateout(ymm_reg) out, in(ymm_reg) a, in(xmm_reg) b, options(pure, nomem, preserves_flags, nostack)); + out + } +} + +#[inline(always)] +pub fn extract_hi_half(v: arch::__m256i) -> arch::__m128i { + unsafe { + arch::_mm256_extracti128_si256(v, 1) + } +} + +macro_rules! 
extract_lohi_bytes { + (($mask:expr, $op12:ident, $op3:ident), $in:ident) => {{ + const MASK: arch::__m128i = unsafe { core::mem::transmute($mask) }; + unsafe { + let out: _; + core::arch::asm!( + //concat!("vmovdqa {mask}, xmmword ptr [rip + .", stringify!($mask), "]"), + "vextracti128 {inter}, {input:y}, 1", + concat!(stringify!($op12), " {inter}, {inter}, {mask}"), + concat!(stringify!($op12), " {output:x}, {input:x}, {mask}"), + concat!(stringify!($op3), " {output:x}, {output:x}, {inter}"), + mask = in(xmm_reg) MASK, input = in(ymm_reg) $in, output = lateout(xmm_reg) out, inter = out(xmm_reg) _, + options(pure, nomem, preserves_flags, nostack) + ); + out + } + }}; +} + +#[inline(always)] +pub fn extract_lo_bytes(v: arch::__m256i) -> arch::__m128i { + extract_lohi_bytes!(([0x00ffu16; 8], vpand, vpackuswb), v) +} + +#[inline(always)] +pub fn extract_hi_bytes(v: arch::__m256i) -> arch::__m128i { + extract_lohi_bytes!( + ( + [0x1u8, 0x3, 0x5, 0x7, 0x9, 0xb, 0xd, 0xf, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0], + vpshufb, + vpunpcklqdq + ), + v + ) +} + +pub trait SimdTestAnd { + /// Returns true if the result of the bitwise AND of `self` and `mask` is not all zero. + fn test_and_non_zero(self, mask: Self) -> bool; +} + +pub trait SimdBitwise { + /// Returns the bitwise OR of `self` and `rhs`. + fn or(self, rhs: Self) -> Self; + + /// Returns the bitwise AND of `self` and `rhs`. + fn and(self, rhs: Self) -> Self; +} + +pub trait SimdWiden { + /// Widens the lower bytes by spacing and zero-extending them to N bytes. + fn widen(self) -> Self; +} + +#[cfg(target_feature = "avx")] +impl SimdTestAnd for arch::__m128i { + #[inline(always)] + fn test_and_non_zero(self, mask: Self) -> bool { + unsafe { + let out: u8; + core::arch::asm!("vptest {a}, {b}", "jz 2f", "mov {out}, 1", "jnz 3f", "2:", "mov {out}, 0", "3:", a = in(xmm_reg) self, b = in(xmm_reg) mask, out = lateout(reg_byte) out, options(pure, nomem, nostack)); + core::mem::transmute(out) + } + } +} + +#[cfg(target_feature = "avx")] +impl SimdTestAnd for arch::__m256i { + #[inline(always)] + fn test_and_non_zero(self, mask: Self) -> bool { + unsafe { + let out: u8; + core::arch::asm!("vptest {a}, {b}", "jz 2f", "mov {out}, 1", "jnz 3f", "2:", "mov {out}, 0", "3:", a = in(ymm_reg) self, b = in(ymm_reg) mask, out = lateout(reg_byte) out, options(pure, nomem, nostack)); + core::mem::transmute(out) + } + } +} + +const USE_BITWISE_INTRINSICS: bool = true; + +#[cfg(target_feature = "avx")] +impl SimdBitwise for arch::__m128i { + #[inline(always)] + fn or(self, rhs: Self) -> Self { + unsafe { + if USE_BITWISE_INTRINSICS { + arch::_mm_or_si128(self, rhs) + } else { + let out: _; + core::arch::asm!("vpor {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = lateout(xmm_reg) out, options(pure, nomem, preserves_flags, nostack)); + out + } + } + } + + #[inline(always)] + fn and(self, rhs: Self) -> Self { + unsafe { + if USE_BITWISE_INTRINSICS { + arch::_mm_and_si128(self, rhs) + } else { + let out: _; + core::arch::asm!("vpand {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = lateout(xmm_reg) out, options(pure, nomem, preserves_flags, nostack)); + out + } + } + } +} + +#[cfg(target_feature = "avx2")] +impl SimdBitwise for arch::__m256i { + #[inline(always)] + fn or(self, rhs: Self) -> Self { + unsafe { + if USE_BITWISE_INTRINSICS { + arch::_mm256_or_si256(self, rhs) + } else { + let out: _; + core::arch::asm!("vpor {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = lateout(ymm_reg) out, options(pure, nomem, 
preserves_flags, nostack)); + out + } + } + } + + #[inline(always)] + fn and(self, rhs: Self) -> Self { + unsafe { + if USE_BITWISE_INTRINSICS { + arch::_mm256_and_si256(self, rhs) + } else { + let out: _; + core::arch::asm!("vpand {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = lateout(ymm_reg) out, options(pure, nomem, preserves_flags, nostack)); + out + } + } + } +} + +macro_rules! widen_128_impl { + ($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:literal, $offset:literal) => {unsafe { + const LEN: usize = core::mem::size_of::<$Self>(); + const MASK_MOD_EQ: usize = if $be { 0 } else { $shift }; + // TODO: evaluate whether there is a more efficient approach for the ignored bytes. + const INDICES: [u8; LEN] = array_op!(gen[LEN] |i| { + if (i + 1) % ($shift + 1) == MASK_MOD_EQ { + ((i as u8) >> $shift) + $offset + } else { + SHUF_0 + } + }); + const INDICES_SIMD: $Self = unsafe { util::cast(INDICES) }; + //const MASK: [u8; LEN] = array_op!(gen[LEN] |i| { + // if (i + 1) % ($shift + 1) == MASK_MOD_EQ { 0xff } else { 0x00 } + //}); + //const MASK: [u8; LEN] = array_op!(gen[LEN] |i| if ((i + 1) % ($shift + 1) == $offset) { 0xff } else { 0x00 }); + let out: _; + core::arch::asm!("vpshufb {out}, {in}, {indices}", in = in($reg) $self, indices = in($reg) INDICES_SIMD, out = lateout($reg) out, options(pure, nomem, preserves_flags, nostack)); + if_trace_simd! { + println!("Offset: {}", $offset); + println!("Indices: {:?}", INDICES); + //println!("MASK: {MASK:?}"); + } + out + //SimdBitwise::and(out, util::cast(MASK)) + }}; + ($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:ident, $offset_in:ident [$( $offset:literal ),+]) => { + match $offset_in { + $( $offset => if $be { widen_128_impl!($self, $Self, $reg, $shift, true, $offset) } else { widen_128_impl!($self, $Self, $reg, $shift, false, $offset) }, )+ + _ => panic!("Unsupported widen O value"), + } + }; + ($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:ident, $offset:ident) => { + widen_128_impl!($self, $Self, $reg, $shift, $be, $offset [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]) + }; +} + +macro_rules! widen_256_impl { + ($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:literal, $offset:literal) => {unsafe { + const LEN: usize = core::mem::size_of::<$Self>(); + const MASK_MOD_EQ: usize = if $be { 0 } else { $shift }; + // TODO: evaluate whether there is a more efficient approach for the ignored bytes. 
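+ // Note on the shuffle masks below: any index byte with its high bit set (SHUF_0)
+ // makes `vpshufb` write a zero into that output lane; every other position selects
+ // the source byte `(i >> $shift) + $offset`. The two shuffles are then merged with
+ // `lo.or(hi.and(HI_MASK_SIMD))`, where HI_MASK keeps only the upper half of `hi`.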
+ const INDICES_LO: [u8; LEN] = array_op!(gen[LEN] |i| { + let i = i % LEN; + if (i + 1) % ($shift + 1) == MASK_MOD_EQ { + ((i as u8) >> $shift) + $offset + } else { + SHUF_0 + } + }); + const INDICES_HI: [u8; LEN] = array_op!(gen[LEN] |i| { + let i = (i % (LEN / 2)) + LEN / 2; + if (i + 1) % ($shift + 1) == MASK_MOD_EQ { + ((i as u8) >> $shift) + $offset + } else { + SHUF_0 + } + }); + const INDICES_LO_SIMD: $Self = unsafe { util::cast(INDICES_LO) }; + const INDICES_HI_SIMD: $Self = unsafe { util::cast(INDICES_HI) }; + const HI_MASK: [u8; LEN] = array_op!(gen[LEN] |i| if i < LEN / 2 { 0x00 } else { 0xff }); + const HI_MASK_SIMD: $Self = unsafe { util::cast(HI_MASK) }; + //const MASK: [u8; LEN] = array_op!(gen[LEN] |i| { + // if (i + 1) % ($shift + 1) == MASK_MOD_EQ { 0xff } else { 0x00 } + //}); + //const MASK: [u8; LEN] = array_op!(gen[LEN] |i| if ((i + 1) % ($shift + 1) == $offset) { 0xff } else { 0x00 }); + let lo: $Self; + core::arch::asm!("vpshufb {out:x}, {in:x}, {indices:x}", in = in($reg) $self, indices = in($reg) INDICES_LO_SIMD, out = lateout($reg) lo, options(pure, nomem, preserves_flags, nostack)); + let hi: $Self; + core::arch::asm!("vpshufb {out:x}, {in:x}, {indices:x}", in = in($reg) $self, indices = in($reg) INDICES_HI_SIMD, out = lateout($reg) hi, options(pure, nomem, preserves_flags, nostack)); + let out = lo.or(hi.and(HI_MASK_SIMD)); + if_trace_simd! { + println!("Offset: {}", $offset); + println!("Indices: {:?},{:?}", INDICES_LO, INDICES_HI); + //println!("MASK: {MASK:?}"); + } + out + //SimdBitwise::and(out, util::cast(MASK)) + }}; + ($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:ident, $offset_in:ident [$( $offset:literal ),+]) => { + match $offset_in { + $( $offset => if $be { widen_256_impl!($self, $Self, $reg, $shift, true, $offset) } else { widen_256_impl!($self, $Self, $reg, $shift, false, $offset) }, )+ + _ => panic!("Unsupported widen O value"), + } + }; + ($self:ident, $Self:ty, $reg:ident, $shift:literal, $be:ident, $offset:ident) => { + widen_256_impl!($self, $Self, $reg, $shift, $be, $offset [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]) + }; +} + +#[doc(hidden)] +#[macro_export] +macro_rules! 
__simd__shr { + ($inst:ident, $n:literal, ($in_reg:ident) $in:expr) => {{ + let out: _; + core::arch::asm!(concat!(stringify!($inst), " {dst}, {src}, ", $n), src = in($in_reg) $in, dst = lateout($in_reg) out); + out + }}; + (16, $n:literal, ($in_reg:ident) $in:expr) => { + $crate::simd::shr!(vpsrlw, $n, ($in_reg) $in) + }; + (32, $n:literal, ($in_reg:ident) $in:expr) => { + $crate::simd::shr!(vpsrld, $n, ($in_reg) $in) + }; + (64, $n:literal, ($in_reg:ident) $in:expr) => { + $crate::simd::shr!(vpsrlq, $n, ($in_reg) $in) + }; +} + +pub use __simd__shr as shr; + +/*impl SimdWiden for arch::__m128i { + #[inline(always)] + fn widen(self) -> Self { + match N { + 1 => self, + 2 => widen_128_impl!(self, arch::__m128i, xmm_reg, 1, BE, O), + 4 => widen_128_impl!(self, arch::__m128i, xmm_reg, 2, BE, O), + 8 => widen_128_impl!(self, arch::__m128i, xmm_reg, 3, BE, O), + _ => panic!("Unsupported widen N value"), + } + } +} + +impl SimdWiden for arch::__m256i { + #[inline(always)] + fn widen(self) -> Self { + match N { + 1 => self, + 2 => widen_256_impl!(self, arch::__m256i, ymm_reg, 1, BE, O), + 4 => widen_256_impl!(self, arch::__m256i, ymm_reg, 2, BE, O), + 8 => widen_256_impl!(self, arch::__m256i, ymm_reg, 3, BE, O), + 16 => widen_256_impl!(self, arch::__m256i, ymm_reg, 4, BE, O), + _ => panic!("Unsupported widen N value"), + } + } +}*/ + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_splat() { + const EXPECTED: [u8; 32] = [3u8; 32]; + const ACTUAL_CONST: Simd = 3u8.splat::<32>(); + const EQ_CONST: bool = { + let mut i = 0; + loop { + if i == 32 { + break true; + } + if ACTUAL_CONST.to_array()[i] != EXPECTED[i] { + break false; + } + i += 1; + } + }; + + let actual = 3u8.splat::<32>(); + assert_eq!(Simd::from(actual), Simd::from_array(EXPECTED)); + assert!(EQ_CONST, "constant evaluated splat did not produce the expected result"); + } + + #[test] + fn test_splat_zero() { + const EXPECTED: [u8; 32] = [0u8; 32]; + const ACTUAL_CONST: Simd = u8::splat_zero::<32>(); + const EQ_CONST: bool = { + let mut i = 0; + loop { + if i == 32 { + break true; + } + if ACTUAL_CONST.to_array()[i] != EXPECTED[i] { + break false; + } + i += 1; + } + }; + + let actual = u8::splat_zero::<32>(); + assert_eq!(Simd::from(actual), Simd::from_array(EXPECTED)); + assert!(EQ_CONST, "constant evaluated splat_zero did not produce the expected result"); + } + + /*#[test] + fn test_widen() { + const SAMPLE_IN: [u8; 32] = [ + 1, 2, 3, 4, 5, 6, 7, 8, + 1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 9, 10, 11, 12, 13, 14, 15, 16 + ]; + + const SAMPLE_OUT_LO_BE: [u8; 32] = [ + 0, 1, 0, 2, 0, 3, 0, 4, + 0, 5, 0, 6, 0, 7, 0, 8, + 0, 1, 0, 2, 0, 3, 0, 4, + 0, 5, 0, 6, 0, 7, 0, 8 + ]; + const SAMPLE_OUT_HI_BE: [u8; 32] = [ + 0, 9, 0, 10, 0, 11, 0, 12, + 0, 13, 0, 14, 0, 15, 0, 16, + 0, 9, 0, 10, 0, 11, 0, 12, + 0, 13, 0, 14, 0, 15, 0, 16 + ]; + + let sample: arch::__m256i = Simd::from_array(SAMPLE_IN).into(); + let widened = sample.widen::<2, 0, true>(); + //assert_eq!(Simd::from(widened), Simd::from_array(SAMPLE_OUT_LO_BE)); + let widened = sample.widen::<2, 16, true>(); + assert_eq!(Simd::from(widened), Simd::from_array(SAMPLE_OUT_HI_BE)); + + const SAMPLE_OUT_LO_LE: [u8; 32] = [ + 1, 0, 2, 0, 3, 0, 4, 0, + 5, 0, 6, 0, 7, 0, 8, 0, + 1, 0, 2, 0, 3, 0, 4, 0, + 5, 0, 6, 0, 7, 0, 8, 0 + ]; + const SAMPLE_OUT_HI_LE: [u8; 32] = [ + 9, 0, 10, 0, 11, 0, 12, 0, + 13, 0, 14, 0, 15, 0, 16, 0, + 9, 0, 10, 0, 11, 0, 12, 0, + 13, 0, 14, 0, 15, 0, 16, 0 + ]; + + let sample: arch::__m256i = Simd::from_array(SAMPLE_IN).into(); + 
let widened = sample.widen::<2, 0, false>(); + //assert_eq!(Simd::from(widened), Simd::from_array(SAMPLE_OUT_LO_LE)); + let widened = sample.widen::<2, 16, false>(); + assert_eq!(Simd::from(widened), Simd::from_array(SAMPLE_OUT_HI_LE)); + }*/ +} diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 0000000..30b13f7 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,255 @@ +use std::num::NonZeroUsize; + +/// Compile-time array operations. +/// +/// # Generation +/// +/// ```rust +/// const I_TIMES_2: [usize; 32] = array_op!(gen[32] |i| i * 2); +/// ``` +/// +/// # Mapping +/// +/// ```rust +/// const I_TIMES_2_PLUS_1: [usize; 32] = array_op!(map[32, I_TIMES_2] |i, x| x + 1); +/// ``` +#[macro_export] +macro_rules! array_op { + (gen[$len:expr] |$i:pat_param| $val:expr) => {{ + let mut out = ::core::mem::MaybeUninit::uninit_array::<$len>(); + let mut i = 0; + while i < $len { + out[i] = ::core::mem::MaybeUninit::new(match i { + $i => $val, + }); + i += 1; + } + #[allow(unused_unsafe)] + unsafe { + ::core::mem::MaybeUninit::array_assume_init(out) + } + }}; + (map[$len:expr, $src:expr] |$i:pat_param, $s:pat_param| $val:expr) => {{ + $crate::array_op!( + gen[$len] + | i + | match i { + $i => match $src[i] { + $s => $val, + }, + } + ) + }}; +} + +#[inline] +pub const fn cast_u8_u32(arr: [u8; N]) -> [u32; N] { + array_op!(map[N, arr] |_, v| v as u32) +} + +#[macro_export] +macro_rules! defer_impl { + ( + => $impl:ident; + + $( fn $name:ident($( $pname:ident: $pty:ty ),*) -> $rty:ty; )* + ) => { + $( + #[inline(always)] + fn $name($( $pname: $pty ),*) -> $rty { + <$impl>::$name($( $pname ),*) + } + )* + }; +} + +#[inline(always)] +#[cold] +pub const fn cold() {} + +#[inline(always)] +pub const fn likely(b: bool) -> bool { + core::intrinsics::likely(b) +} + +#[inline(always)] +pub const fn unlikely(b: bool) -> bool { + core::intrinsics::unlikely(b) +} + +/// Like transmute, but implemented via a union so that we can use it in situations where +/// transmute's "safety" restrictions are too strict or uninformed (i.e. we can prove it is safe +/// or we simply don't care). +#[inline(always)] +pub const unsafe fn cast(a: A) -> B { + union Cast { + a: core::mem::ManuallyDrop, + b: core::mem::ManuallyDrop, + } + + core::mem::ManuallyDrop::into_inner( + Cast { + a: core::mem::ManuallyDrop::new(a), + } + .b, + ) +} + +#[doc(hidden)] +#[macro_export] +macro_rules! __subst { + ([$( $ignore:tt )*], [$( $use:tt )*]) => { + $( $use )* + }; + } + +pub use __subst as subst; + +#[inline(always)] +pub const fn align_down_to(n: usize) -> usize { + let shift = match N.checked_ilog2() { + Some(x) => x, + None => 0, + }; + return n >> shift << shift; +} + +#[inline(always)] +pub const fn align_up_to(n: usize) -> usize { + let shift = match N.checked_ilog2() { + Some(x) => x, + None => 0, + }; + return (n + (N - 1)) >> shift << shift; +} + +/// This function is unsafe because the caller must ensure that the value is not used until it has +/// been properly initialized. +#[cfg(feature = "alloc")] +#[inline(always)] +pub unsafe fn alloc_aligned_box() -> Box { + match NonZeroUsize::new(core::mem::size_of::()) { + Some(size) => { + let ptr = alloc_aligned( + size, + core::cmp::max(core::mem::align_of::(), ALIGN), + ); + Box::from_raw(ptr as *mut T) + } + None => Box::from_raw(core::ptr::NonNull::dangling().as_ptr()), + } +} + +/// This function is unsafe because the caller must ensure that the value is not used until it has +/// been properly initialized. 
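+/// The returned box holds `len` uninitialized elements and is allocated with an
+/// alignment of at least `ALIGN`.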
+#[cfg(feature = "alloc")] +#[inline(always)] +pub unsafe fn alloc_aligned_box_slice(len: usize) -> Box<[T]> { + if core::mem::size_of::() == 0 || len == 0 { + return Box::from_raw(core::ptr::NonNull::from_raw_parts(core::ptr::NonNull::dangling(), 0).as_ptr()); + } + let size = core::cmp::max(core::mem::size_of::(), core::mem::align_of::()); + let ptr = alloc_aligned(NonZeroUsize::new_unchecked(size * len), core::cmp::max(core::mem::align_of::(), ALIGN)); + Box::from_raw(core::ptr::slice_from_raw_parts_mut(ptr as *mut T, len)) +} + +/// This function is unsafe because all the preconditions of +/// `std::alloc::Layout::from_size_align_unchecked` must be upheld by the caller. +#[cfg(feature = "alloc")] +#[inline(always)] +pub unsafe fn alloc_aligned(len: NonZeroUsize, align: usize) -> *mut u8 { + unsafe { + // SAFETY: len is nonzero + let layout = std::alloc::Layout::from_size_align_unchecked(len.get(), align); + let ptr = std::alloc::alloc(layout); + if ptr.is_null() { + std::alloc::handle_alloc_error(layout); + } + ptr + } +} + +#[macro_export] +macro_rules! unroll { + (let [$( $x:ident ),+] => |$y:pat_param| $expr:expr) => { + $crate::unroll!(let [$( $x: ($x) ),+] => |$y| $expr); + }; + (let [$( $id:ident: ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => { + $crate::unroll!(@ [$] [$( $id: ($( $x ),+) ),+] => |$( $y ),+| $($expr)+) + }; + (@ [$dollar:tt] [$( $id:ident: ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => { + macro_rules! __unrolled { + ($dollar id:ident, $dollar ( $dollar z:tt ),+) => { + #[allow(unused_parens)] + let ($( $y ),+) = ($dollar ( $dollar z ),+); + let $dollar id = $($expr)+; + }; + } + $( __unrolled!($id, $( $x ),+); )+ + //$( + // let ($( $y ),+) = ($( $x ),+); + // $expr; + //)+ + }; + ([$( ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => { + $crate::unroll!(@ [$] [$( ($( $x ),+) ),+] => |$( $y ),+| $($expr)+) + }; + (@ [$dollar:tt] [$( ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => { + macro_rules! __unrolled { + ($dollar ( $dollar z:tt ),+) => { + #[allow(unused_parens)] + let ($( $y ),+) = ($dollar ( $dollar z ),+); + $($expr)+; + }; + } + $( __unrolled!($( $x ),+); )+ + //$( + // let ($( $y ),+) = ($( $x ),+); + // $expr; + //)+ + }; +} + +#[cfg(test)] +mod test { + use super::*; + use crate::array_op; + + #[test] + pub fn test_align_down_to() { + assert_eq!(align_down_to::<8>(8), 8); + assert_eq!(align_down_to::<16>(8), 0); + assert_eq!(align_down_to::<16>(16), 16); + assert_eq!(align_down_to::<16>(15), 0); + assert_eq!(align_down_to::<16>(17), 16); + } + + #[test] + pub fn test_array_op_gen() { + assert_eq!(array_op!(gen[4] | i | i), [0, 1, 2, 3]); + assert_eq!( + array_op!(gen[16] | i | i), + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + ); + assert_eq!(array_op!(gen[4] | i | i as u8), [0u8, 1, 2, 3]); + assert_eq!( + array_op!(gen[16] | i | i as u8), + [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + ); + + const A: [u8; 4] = array_op!(gen[4] | i | i as u8); + const B: [u8; 16] = array_op!(gen[16] | i | i as u8); + assert_eq!(A, [0u8, 1, 2, 3]); + assert_eq!(B, [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]); + + const LEN: usize = core::mem::size_of::(); + const INDICES: [u8; LEN] = array_op!(gen[LEN] | i | (i as u8) >> 1); + assert_eq!( + INDICES, + [ + 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, + 13, 13, 14, 14, 15, 15 + ] + ); + } +}
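For reference, a short usage sketch of the API introduced by this patch (illustrative only,
not part of the commit; it mirrors the included tests, the `demo` wrapper is hypothetical,
and it assumes a nightly toolchain with the same feature gates as src/lib.rs):

    #![feature(portable_simd, maybe_uninit_uninit_array, maybe_uninit_array_assume_init)]
    use brisk::array_op;
    use brisk::simd::SimdSplat;

    fn demo() {
        // Broadcast a byte across all 32 lanes (see simd::test::test_splat).
        let v = 3u8.splat::<32>();
        assert_eq!(v.to_array(), [3u8; 32]);

        // Build an array from its indices with the gen form of array_op! (see util.rs docs).
        let doubled: [usize; 8] = array_op!(gen[8] |i| i * 2);
        assert_eq!(doubled, [0, 2, 4, 6, 8, 10, 12, 14]);
    }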