fast-hex/src/simd.rs

396 lines
18 KiB
Rust
Raw Normal View History

use core::simd::{LaneCount, Simd, SupportedLaneCount, SimdElement};
2022-10-30 15:35:17 -04:00
use crate::util::cast;
// Widths, in bytes, of 128-, 256- and 512-bit vector registers. A `u8` lane is one byte, so
// these are also the lane counts of `Simd<u8, _>` vectors of those sizes (that is how
// `splat_n` matches on them). For wider lanes divide by the lane size in bytes.
const W_128: usize = 128 / 8;
const W_256: usize = 256 / 8;
const W_512: usize = 512 / 8;
2022-10-31 22:42:01 -04:00
#[cfg(target_arch = "aarch64")]
2022-11-01 00:50:30 -04:00
pub use core::arch::aarch64 as arch;
2022-10-31 22:42:01 -04:00
#[cfg(target_arch = "arm")]
2022-11-01 00:50:30 -04:00
pub use core::arch::arm as arch;
2022-10-31 22:42:01 -04:00
#[cfg(target_arch = "wasm32")]
2022-11-01 00:50:30 -04:00
pub use core::arch::wasm32 as arch;
2022-10-31 22:42:01 -04:00
#[cfg(target_arch = "wasm64")]
2022-11-01 00:50:30 -04:00
pub use core::arch::wasm64 as arch;
2022-10-31 22:42:01 -04:00
#[cfg(target_arch = "x86")]
2022-11-01 00:50:30 -04:00
pub use core::arch::x86 as arch;
2022-10-31 22:42:01 -04:00
#[cfg(target_arch = "x86_64")]
2022-11-01 00:50:30 -04:00
pub use core::arch::x86_64 as arch;
2022-10-31 22:42:01 -04:00
2022-11-01 19:03:28 -04:00
// Disabled alternative: the batch size that AVX-512 would support.
//pub(crate) const SIMD_WIDTH: usize = 512;
/// Batch width used by the SIMD code paths. The active value is 256 (the AVX2 `ymm` size);
/// NOTE(review): the unit is presumably bits, given the 512/AVX-512 alternative above —
/// confirm against the users of this constant.
pub const SIMD_WIDTH: usize = 256;
/// Tracing hook for the SIMD code: its arguments are only compiled in when tracing is on.
///
/// Tracing is currently switched off, so every invocation expands to nothing and the
/// wrapped tokens are never type-checked or executed. To switch it on, make the rule
/// expand to `{ $( $tokens )* }` instead.
#[macro_export]
macro_rules! __if_trace_simd {
    ($( $tokens:tt )*) => {};
}
pub use __if_trace_simd as if_trace_simd;
/// Compile-time description of a portable-SIMD vector type.
pub trait IsSimd {
    /// Element type stored in each lane.
    type Lane;
    /// Number of lanes in the vector.
    const LANES: usize;
}

impl<T, const N: usize> IsSimd for Simd<T, N>
where
    T: SimdElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Lane = T;
    const LANES: usize = N;
}
2022-10-30 15:35:17 -04:00
/// Generates code whose body is a `match` on a const-generic lane count, so that each
/// monomorphization keeps only the arm for its width ("specialization" without the
/// unsoundness of real specialization).
///
/// Each arm is `WIDTH [if CFG] => EXPR`. `CFG` becomes a `#[cfg(CFG)]` on that arm, so arms
/// for unavailable targets/features are compiled out and a later arm (usually `_`) is used.
///
/// Free-function form (`LANES` becomes the function's `const LANES: usize` parameter):
/// ```ignore
/// specialized! {
///     pub fn name<LANES, [T: Copy]>(x: T) -> R where [LaneCount<LANES>: SupportedLaneCount] {
///         W_128 if target_feature = "sse2" => ...,
///         _ => ...,
///     }
/// }
/// ```
///
/// Trait-impl form (`LANES` becomes the impl's const parameter; every method must name the
/// same `LANES` ident; the optional `where [...]` after the type bounds the impl itself,
/// which `Simd<_, LANES>` types need):
/// ```ignore
/// specialized! {
///     Trait for Ty<LANES> where [LaneCount<LANES>: SupportedLaneCount];
///     fn method<LANES>(x: u8) -> R { W_128 => ..., _ => ... }
/// }
/// ```
macro_rules! specialized {
    ($(
        $vis:vis fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
            $(
                $width:pat_param $( if $cfg:meta )? => $impl:expr
            ),+
            $(,)?
        }
    )+) => {$(
        #[allow(dead_code)]
        #[inline(always)]
        $vis fn $name<const $LANES: usize$(, $( $generics )+)?>($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
            // The scrutinee is a constant, so after monomorphization only one arm survives.
            match $LANES {
                $(
                    $( #[cfg( $cfg )] )?
                    $width => $impl
                ),+
            }
        }
    )+};
    // Trait-impl entry point. The const parameter name has to be known outside the per-method
    // repetition, so it is taken from the first method and forwarded to the `@impl` rule.
    ($trait:ident for $ty:ty $(where [ $( $impl_where:tt )* ])?;
        fn $first:ident<$LANES:ident $( $rest:tt )*
    ) => {
        specialized!(@impl [$LANES] [$trait] [$ty] [$( $( $impl_where )* )?]
            fn $first<$LANES $( $rest )*);
    };
    (@impl [$L:ident] [$trait:ident] [$ty:ty] [$( $impl_where:tt )*] $(
        fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
            $(
                $width:pat_param $( if $cfg:meta )? => $impl:expr
            ),+
            $(,)?
        }
    )+) => {
        impl<const $L: usize> $trait for $ty where $( $impl_where )* {$(
            #[inline(always)]
            fn $name$(<$( $generics )+>)?($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
                // Same constant-scrutinee trick as the free-function form.
                match $L {
                    $(
                        $( #[cfg( $cfg )] )?
                        $width => $impl
                    ),+
                }
            }
        )+}
    };
}
/// Broadcasts the scalar `$n` to every element of a vector register with one
/// `vpbroadcast*` instruction. Must be expanded inside `unsafe`.
///
/// * `$arch`: `x86` or `x86_64`, the `core::arch` module to take types/intrinsics from.
/// * `$inst`: the broadcast mnemonic (`vpbroadcastb`, `vpbroadcastw`, ...).
/// * `$vec`: the output vector type (`__m128i` / `__m256i`).
/// * `$reg`: the asm register class matching `$vec` (`xmm_reg` / `ymm_reg`).
/// * `$n`: an integer scalar identifier (u8/u16 at the call sites in this file).
macro_rules! set1 {
    ($arch:ident, $inst:ident, $vec:ident, $reg:ident, $n:ident) => {{
        // The broadcast reads its source element from an XMM register. The scalar is moved
        // there with `movd` (zero-extended to 32 bits) rather than bit-cast, because a 1- or
        // 2-byte integer is not the same size as a 16-byte `__m128i`.
        let src: core::arch::$arch::__m128i = core::arch::$arch::_mm_cvtsi32_si128($n as i32);
        let out: core::arch::$arch::$vec;
        core::arch::asm!(
            concat!(stringify!($inst), " {}, {}"),
            lateout($reg) out,
            in(xmm_reg) src,
            options(pure, nomem, preserves_flags, nostack),
        );
        out
    }};
}
specialized! {
// TODO: special case https://www.felixcloutier.com/x86/vpbroadcastb:vpbroadcastw:vpbroadcastd:vpbroadcastq
pub fn splat_n<LANES>(n: u8) -> Simd<u8, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastb, __m128i, xmm_reg, n)) },
W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastb, __m128i, xmm_reg, n)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastb, __m256i, ymm_reg, n)) },
W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastb, __m256i, ymm_reg, n)) },
2022-10-30 15:35:17 -04:00
2022-10-31 22:42:01 -04:00
// these are *terrible*. They compile to a bunch of MOVs and SETs
W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) },
W_128 if all(target_arch = "x86", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm256_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm256_set1_epi8(n as i8)) },
// I can't really test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
2022-10-30 15:35:17 -04:00
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi8(n as i8)) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi8(n as i8)) },
_ => Simd::splat(n),
}
// TODO: special case https://www.felixcloutier.com/x86/vpbroadcastb:vpbroadcastw:vpbroadcastd:vpbroadcastq
pub fn splat_u16<LANES>(n: u16) -> Simd<u16, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastw, __m128i, xmm_reg, n)) },
W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastw, __m128i, xmm_reg, n)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastw, __m256i, ymm_reg, n)) },
W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastw, __m256i, ymm_reg, n)) },
// these are *terrible*. They compile to a bunch of MOVs and SETs
W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi16(n as i16)) },
W_128 if all(target_arch = "x86", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm_set1_epi16(n as i16)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm256_set1_epi16(n as i16)) },
W_256 if all(target_arch = "x86", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm256_set1_epi16(n as i16)) },
// I can't really test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi16(n as i16)) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi16(n as i16)) },
_ => Simd::splat(n),
}
pub fn splat_0u8<LANES>() -> Simd<u8, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
// these are fine, they are supposed to XOR themselves to zero out.
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
_ => unsafe { crate::util::cast(splat_n(0)) },
}
pub fn splat_0u16<LANES>() -> Simd<u16, LANES> where [
2022-10-31 22:42:01 -04:00
LaneCount<LANES>: SupportedLaneCount,
] {
// these are fine, they are supposed to XOR themselves to zero out.
2022-10-30 15:35:17 -04:00
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
_ => unsafe { crate::util::cast(splat_n(0)) },
2022-10-31 22:42:01 -04:00
}
}
/// Defines the indices used by [`swizzle`].
///
/// Emits, via `global_asm!`, a label `.$label` followed by one `.byte` directive per index.
/// The optional second bracket list appends padding through `crate::util::subst!`
/// (presumably one `.zero 1` per padding token; confirm against `subst!`). The label-based
/// forms of `swizzle!` read this table RIP-relative as the `vpshufb` index operand.
#[macro_export]
macro_rules! __swizzle_indices {
    ($label:ident = [$( $byte:literal ),+] $( , [$( $pad:tt )+] )?) => {
        core::arch::global_asm!(
            concat!(".", stringify!($label), ":")
            $( , concat!("\n .byte ", stringify!($byte)) )+
            $( $( , $crate::util::subst!([$pad], ["\n .zero 1"]) )+ )?
        );
    };
}
/// Emits a single `vpshufb` byte shuffle on vector registers. Must be expanded inside `unsafe`.
///
/// In the Intel syntax used by `asm!`, the destination is the first operand:
/// `vpshufb dest, src, indices`.
///
/// Forms:
/// * `swizzle!(xmm_reg | ymm_reg | zmm_reg, src, MODE dest, LABEL)`: indices are read from the
///   table emitted by `swizzle_indices!(LABEL = [...])`; `MODE` is the asm output kind used
///   for `dest` (e.g. `lateout`).
/// * `swizzle!(REG_CLASS, src, dest, mem PTR)`: indices are read from memory at `PTR`.
/// * `swizzle!(REG_CLASS, src, dest, data(IDX_CLASS) IDX)`: indices are held in a register.
#[macro_export]
macro_rules! __swizzle {
    (xmm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => {
        $crate::simd::swizzle!(@ xmm_reg, x, $src, $mode $dest, (xmmword) $indices)
    };
    (ymm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => {
        $crate::simd::swizzle!(@ ymm_reg, y, $src, $mode $dest, (ymmword) $indices)
    };
    // Forwards in the same `@ class, modifier, ...` shape as the xmm/ymm rules so it matches
    // the `@` rule below.
    (zmm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => {
        $crate::simd::swizzle!(@ zmm_reg, z, $src, $mode $dest, (zmmword) $indices)
    };
    // Named operands make the Intel operand order explicit: `dest` is written, `src` only read.
    ($reg:ident, $src:expr, $dest:expr, mem $indices:expr) => {
        core::arch::asm!(
            "vpshufb {dest}, {src}, [{indices}]",
            dest = lateout($reg) $dest,
            src = in($reg) $src,
            indices = in(reg) $indices,
            options(readonly, preserves_flags, nostack),
        );
    };
    ($reg:ident, $src:expr, $dest:expr, data($indices_reg:ident) $indices:expr) => {
        core::arch::asm!(
            "vpshufb {dest}, {src}, {indices}",
            dest = lateout($reg) $dest,
            src = in($reg) $src,
            indices = in($indices_reg) $indices,
            options(pure, nomem, preserves_flags, nostack),
        );
    };
    // `$token` is the operand modifier (`x`/`y`/`z`) and `$indices_reg` the memory-operand size
    // keyword (`xmmword`/`ymmword`/`zmmword`) for the register width in use.
    (@ $reg:ident, $token:ident, $src:expr, $mode:ident $dest:expr, ($indices_reg:ident) $indices:ident) => {
        core::arch::asm!(
            concat!(
                "vpshufb {:", stringify!($token), "}, {:", stringify!($token), "}, ",
                stringify!($indices_reg), " ptr [rip + .", stringify!($indices), "]"
            ),
            $mode($reg) $dest,
            in($reg) $src,
            options(pure, nomem, preserves_flags, nostack),
        );
    };
}
pub use __swizzle as swizzle;
pub use __swizzle_indices as swizzle_indices;
/// Places `v` in the low 64 bits of a 128-bit vector using `vmovq xmm, r64`; that form
/// zeroes the upper 64 bits.
///
/// NOTE(review): `vmovq` here is VEX-encoded (AVX) and this function has no `cfg` guard;
/// callers presumably only reach it on AVX targets.
#[inline(always)]
pub fn load_u64_m128(v: u64) -> arch::__m128i {
    let vector: arch::__m128i;
    // SAFETY: register-to-register move; no memory, stack, or flags are involved.
    unsafe {
        core::arch::asm!(
            "vmovq {}, {}",
            lateout(xmm_reg) vector,
            in(reg) v,
            options(pure, nomem, preserves_flags, nostack),
        );
    }
    vector
}
/// Merges the low 64-bit halves of `a` and `b` into one register (`vpunpcklqdq`): the
/// result's low half is `a`'s low half and its high half is `b`'s low half.
#[inline(always)]
pub fn merge_lo_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
    let merged: arch::__m128i;
    // SAFETY: register-only unpack; no memory, stack, or flags are involved.
    unsafe {
        core::arch::asm!(
            "vpunpcklqdq {}, {}, {}",
            lateout(xmm_reg) merged,
            in(xmm_reg) a,
            in(xmm_reg) b,
            options(pure, nomem, preserves_flags, nostack),
        );
    }
    merged
}
/// The args are in little endian order (first arg is lowest order)
2022-11-01 00:50:30 -04:00
#[inline(always)]
2022-10-31 22:42:01 -04:00
pub fn merge_m128_m256(a: arch::__m128i, b: arch::__m128i) -> arch::__m256i {
unsafe {
let out: _;
2022-11-01 00:50:30 -04:00
core::arch::asm!("vinserti128 {}, {:y}, {}, 0x1", lateout(ymm_reg) out, in(ymm_reg) a, in(xmm_reg) b, options(pure, nomem, preserves_flags, nostack));
2022-10-31 22:42:01 -04:00
out
2022-10-30 15:35:17 -04:00
}
2022-10-31 22:42:01 -04:00
}
/// Narrows a 256-bit vector of 16-bit elements to a 128-bit vector of bytes.
///
/// `$op12` is applied with `$mask` to the upper and lower 128-bit halves separately; then
/// `$op3` combines them with the lower-half result as its first source and the upper-half
/// result as its second. `$mask` must be a 16-byte array (it is transmuted to `__m128i` at
/// compile time). Expands to a block expression; `arch` must be in scope at the call site.
macro_rules! extract_lohi_bytes {
    (($mask:expr, $op12:ident, $op3:ident), $input:ident) => {{
        const MASK: arch::__m128i = unsafe { core::mem::transmute($mask) };
        let packed;
        // SAFETY: register-only sequence; no memory, stack, or flags are involved.
        unsafe {
            core::arch::asm!(
                // inter = upper 128 bits of the input
                "vextracti128 {inter}, {input:y}, 1",
                // apply the per-half operation to both halves
                concat!(stringify!($op12), " {inter}, {inter}, {mask}"),
                concat!(stringify!($op12), " {output:x}, {input:x}, {mask}"),
                // combine: lower-half result first, upper-half result second
                concat!(stringify!($op3), " {output:x}, {output:x}, {inter}"),
                mask = in(xmm_reg) MASK,
                input = in(ymm_reg) $input,
                output = lateout(xmm_reg) packed,
                inter = out(xmm_reg) _,
                options(pure, nomem, preserves_flags, nostack),
            );
        }
        packed
    }};
}
2022-10-30 15:35:17 -04:00
#[inline(always)]
2022-10-31 22:42:01 -04:00
pub fn extract_lo_bytes(v: arch::__m256i) -> arch::__m128i {
extract_lohi_bytes!(([0xffu16; 8], vpand, vpackuswb), v)
2022-10-30 15:35:17 -04:00
}
#[inline(always)]
2022-10-31 22:42:01 -04:00
pub fn extract_hi_bytes(v: arch::__m256i) -> arch::__m128i {
2022-11-01 00:50:30 -04:00
extract_lohi_bytes!(
(
[0x1u8, 0x3, 0x5, 0x7, 0x9, 0xb, 0xd, 0xf, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0],
vpshufb,
vpunpcklqdq
),
v
)
2022-10-30 15:35:17 -04:00
}
/// Whole-register "test under mask": checks whether two vectors share any set bit.
///
/// Implemented in this module for `__m128i` and `__m256i` when the `avx` target feature is
/// enabled.
pub trait SimdTestAnd {
    /// Returns true if the result of the bitwise AND of `self` and `mask` is not all zero.
    fn test_and_non_zero(self, mask: Self) -> bool;
}
/// Lane-agnostic bitwise operations on whole vector registers.
///
/// Implemented in this module for `__m128i` (with `avx`) and `__m256i` (with `avx2`);
/// `USE_BITWISE_INTRINSICS` chooses between intrinsics and inline assembly.
pub trait SimdBitwise {
    /// Returns the bitwise OR of `self` and `rhs`.
    fn or(self, rhs: Self) -> Self;
    /// Returns the bitwise AND of `self` and `rhs`.
    fn and(self, rhs: Self) -> Self;
}
#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m128i {
    /// Returns `true` when `self & mask` has at least one set bit.
    #[inline(always)]
    fn test_and_non_zero(self, mask: Self) -> bool {
        // `_mm_testz_si128` returns `ptest`'s ZF: 1 exactly when `self & mask` is all zero.
        // The intrinsic (rather than `asm!` that materializes the flag into a byte through
        // jumps and transmutes it to `bool`) lets the compiler feed ZF straight into the
        // caller's branch or `setcc`.
        // SAFETY: `cfg(target_feature = "avx")` implies SSE4.1, which `ptest` requires.
        unsafe { arch::_mm_testz_si128(self, mask) == 0 }
    }
}
#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m256i {
    /// Returns `true` when `self & mask` has at least one set bit.
    #[inline(always)]
    fn test_and_non_zero(self, mask: Self) -> bool {
        // `_mm256_testz_si256` returns `vptest`'s ZF: 1 exactly when `self & mask` is all
        // zero. The intrinsic (rather than `asm!` that materializes the flag into a byte
        // through jumps and transmutes it to `bool`) lets the compiler feed ZF straight into
        // the caller's branch or `setcc`.
        // SAFETY: guarded by `cfg(target_feature = "avx")`, which the 256-bit `vptest` needs.
        unsafe { arch::_mm256_testz_si256(self, mask) == 0 }
    }
}
/// Selects how `SimdBitwise` is implemented: `true` uses the `core::arch` intrinsics
/// (`_mm*_or_si*` / `_mm*_and_si*`), `false` uses hand-written `vpor`/`vpand` inline asm.
const USE_BITWISE_INTRINSICS: bool = true;
#[cfg(target_feature = "avx")]
impl SimdBitwise for arch::__m128i {
    /// Bitwise OR of two 128-bit vectors.
    #[inline(always)]
    fn or(self, rhs: Self) -> Self {
        if !USE_BITWISE_INTRINSICS {
            let merged: Self;
            // SAFETY: register-only `vpor`; AVX is guaranteed by the `cfg` on this impl.
            unsafe {
                core::arch::asm!(
                    "vpor {out}, {a}, {b}",
                    a = in(xmm_reg) self,
                    b = in(xmm_reg) rhs,
                    out = out(xmm_reg) merged,
                    options(pure, nomem, preserves_flags, nostack),
                );
            }
            return merged;
        }
        // SAFETY: SSE2 intrinsic; implied by the `avx` `cfg` on this impl.
        unsafe { arch::_mm_or_si128(self, rhs) }
    }

    /// Bitwise AND of two 128-bit vectors.
    #[inline(always)]
    fn and(self, rhs: Self) -> Self {
        if !USE_BITWISE_INTRINSICS {
            let masked: Self;
            // SAFETY: register-only `vpand`; AVX is guaranteed by the `cfg` on this impl.
            unsafe {
                core::arch::asm!(
                    "vpand {out}, {a}, {b}",
                    a = in(xmm_reg) self,
                    b = in(xmm_reg) rhs,
                    out = out(xmm_reg) masked,
                    options(pure, nomem, preserves_flags, nostack),
                );
            }
            return masked;
        }
        // SAFETY: SSE2 intrinsic; implied by the `avx` `cfg` on this impl.
        unsafe { arch::_mm_and_si128(self, rhs) }
    }
}
#[cfg(target_feature = "avx2")]
impl SimdBitwise for arch::__m256i {
    /// Bitwise OR of two 256-bit vectors.
    #[inline(always)]
    fn or(self, rhs: Self) -> Self {
        if !USE_BITWISE_INTRINSICS {
            let merged: Self;
            // SAFETY: register-only 256-bit `vpor`; AVX2 is guaranteed by the `cfg` on this impl.
            unsafe {
                core::arch::asm!(
                    "vpor {out}, {a}, {b}",
                    a = in(ymm_reg) self,
                    b = in(ymm_reg) rhs,
                    out = out(ymm_reg) merged,
                    options(pure, nomem, preserves_flags, nostack),
                );
            }
            return merged;
        }
        // SAFETY: AVX2 intrinsic; guaranteed by the `cfg` on this impl.
        unsafe { arch::_mm256_or_si256(self, rhs) }
    }

    /// Bitwise AND of two 256-bit vectors.
    #[inline(always)]
    fn and(self, rhs: Self) -> Self {
        if !USE_BITWISE_INTRINSICS {
            let masked: Self;
            // SAFETY: register-only 256-bit `vpand`; AVX2 is guaranteed by the `cfg` on this impl.
            unsafe {
                core::arch::asm!(
                    "vpand {out}, {a}, {b}",
                    a = in(ymm_reg) self,
                    b = in(ymm_reg) rhs,
                    out = out(ymm_reg) masked,
                    options(pure, nomem, preserves_flags, nostack),
                );
            }
            return masked;
        }
        // SAFETY: AVX2 intrinsic; guaranteed by the `cfg` on this impl.
        unsafe { arch::_mm256_and_si256(self, rhs) }
    }
}