//! fast-hex/src/simd.rs — arch-specific SIMD primitives (splat, swizzle,
//! byte extraction) implemented with inline assembly and `std::simd`.
use std::simd::{LaneCount, Simd, SupportedLaneCount};
use crate::util::cast;
// SIMD register widths in bytes — used as `LANES` counts for `Simd<u8, N>`
// when specializing by vector width below.
const W_128: usize = 128 / 8;
const W_256: usize = 256 / 8;
const W_512: usize = 512 / 8;
#[cfg(target_arch = "aarch64")]
use std::arch::aarch64 as arch;
#[cfg(target_arch = "arm")]
use std::arch::arm as arch;
#[cfg(target_arch = "wasm32")]
use std::arch::wasm32 as arch;
#[cfg(target_arch = "wasm64")]
use std::arch::wasm64 as arch;
#[cfg(target_arch = "x86")]
use std::arch::x86 as arch;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64 as arch;
// Generates `#[inline(always)]` functions that dispatch on the const generic
// `$LANES` at compile time. Each arm pairs a lane-width pattern with an
// optional `if <cfg>` clause that is emitted as a `#[cfg(...)]` attribute on
// the match arm, so only implementations valid for the current target are
// compiled in. A trailing `_` arm must be provided as the portable fallback.
macro_rules! specialized {
($(
$vis:vis fn $name:ident<$LANES:ident$(, $( $generics:tt )+)?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
$(
$width:pat_param $( if $cfg:meta )? => $impl:expr
),+
$(,)?
}
)+) => {$(
#[inline(always)]
$vis fn $name<const $LANES: usize$(, $( $generics )+)?>($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
// abusing const generics to specialize without the unsoundness of real specialization!
match $LANES {
$(
$( #[cfg( $cfg )] )?
$width => $impl
),+
}
}
)+};
}
// Broadcasts the low byte of `$n` into every byte of a `$vec` register via
// `vpbroadcastb` (register-source form requires AVX2 — call sites cfg-gate
// on it). Expands to an expression; the caller supplies the enclosing
// `unsafe` block for both the `asm!` and the surrounding `cast`.
macro_rules! set1 {
($arch:ident, $vec:ident, $reg:ident, $n:ident) => {{
let out: std::arch::$arch::$vec;
std::arch::asm!("vpbroadcastb {}, {}", lateout($reg) out, in(xmm_reg) cast::<_, std::arch::$arch::__m128i>($n));
out
}};
}
// `splat_n`: fill every lane of `Simd<u8, LANES>` with `n`; `splat_0`: fill
// with zero. AVX2 gets a single `vpbroadcastb`; older x86 feature levels use
// the `_mm*_set1_epi8` / `_mm*_setzero_*` intrinsics; everything else falls
// back to the portable `Simd::splat`.
// NB: only `//` comments are safe inside this invocation — `///` would become
// `#[doc]` tokens and break the `specialized!` pattern match.
specialized! {
// TODO: special case https://www.felixcloutier.com/x86/vpbroadcastb:vpbroadcastw:vpbroadcastd:vpbroadcastq
pub fn splat_n<LANES>(n: u8) -> Simd<u8, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, __m128i, xmm_reg, n)) },
W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, __m128i, xmm_reg, n)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, __m256i, ymm_reg, n)) },
W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, __m256i, ymm_reg, n)) },
// these are *terrible*. They compile to a bunch of MOVs and SETs
W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) },
W_128 if all(target_arch = "x86", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86_64", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm256_set1_epi8(n as i8)) },
W_256 if all(target_arch = "x86", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm256_set1_epi8(n as i8)) },
// I can't really test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi8(n as i8)) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi8(n as i8)) },
_ => Simd::splat(n),
}
pub fn splat_0<LANES>() -> Simd<u8, LANES> where [
LaneCount<LANES>: SupportedLaneCount,
] {
// these are fine, they are supposed to XOR themselves to zero out.
W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
_ => splat_n(0),
}
}
/// Defines the indices used by [`swizzle`].
///
/// Emits a global-asm label `.{name}:` followed by one `.byte` directive per
/// `$index`, plus one zero byte (`.zero 1`) per `$padding` token, so that
/// [`swizzle`] can load the table RIP-relative by name.
#[macro_export]
macro_rules! __swizzle_indices {
($name:ident = [$( $index:literal ),+] $( , [$( $padding:tt )+] )?) => {
std::arch::global_asm!(concat!(".", stringify!($name), ":")
$( , concat!("\n .byte ", stringify!($index)) )+
$( $( , $crate::util::subst!([$padding], ["\n .zero 1"]) )+ )?);
};
}
/// Emits a `vpshufb` byte shuffle for the given register class.
///
/// Forms:
/// * `swizzle!(xmm_reg|ymm_reg|zmm_reg, src, mode dest, INDICES)` — shuffles
///   `src` by a static index table declared with [`swizzle_indices`], loaded
///   RIP-relative. `mode` is the asm output mode (`lateout`, `inout`, …).
/// * `swizzle!(reg, src, dest, mem indices)` — index table via a pointer in a GPR.
/// * `swizzle!(reg, src, dest, data(reg_class) indices)` — index table already
///   in a SIMD register.
///
/// Safety: expands to `asm!`; the caller must supply the `unsafe` block and
/// ensure the required target features (AVX for xmm/ymm, AVX-512 for zmm)
/// are available.
#[macro_export]
macro_rules! __swizzle {
    (xmm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => {
        $crate::simd::swizzle!(@ xmm_reg, x, $src, $mode $dest, (xmmword) $indices)
    };
    (ymm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => {
        $crate::simd::swizzle!(@ ymm_reg, y, $src, $mode $dest, (ymmword) $indices)
    };
    (zmm_reg, $src:expr, $mode:ident $dest:expr, $indices:ident) => {
        // FIX: was `@ zmm_reg z, ...` (missing comma), which could never match
        // the internal `@` rule below and failed to compile when invoked.
        $crate::simd::swizzle!(@ zmm_reg, z, $src, $mode $dest, (zmmword) $indices)
    };
    ($reg:ident, $src:expr, $dest:expr, mem $indices:expr) => {
        // FIX: the destination must be the FIRST template operand (Intel
        // syntax is `vpshufb dst, src, mem`). Previously `in($reg) $src` was
        // listed first, binding the written operand to an input register and
        // leaving the `lateout` `$dest` never written.
        std::arch::asm!("vpshufb {}, {}, [{}]", lateout($reg) $dest, in($reg) $src, in(reg) $indices);
    };
    ($reg:ident, $src:expr, $dest:expr, data($indices_reg:ident) $indices:expr) => {
        // FIX: same operand-order correction as the `mem` form above.
        std::arch::asm!("vpshufb {}, {}, {}", lateout($reg) $dest, in($reg) $src, in($indices_reg) $indices);
    };
    // Internal: `$token` is the register-size template modifier (x/y/z),
    // `$indices_reg` the memory-operand size keyword (xmmword/ymmword/zmmword).
    (@ $reg:ident, $token:ident, $src:expr, $mode:ident $dest:expr, ($indices_reg:ident) $indices:ident) => {
        std::arch::asm!(concat!("vpshufb {:", stringify!($token), "}, {:", stringify!($token), "}, ", stringify!($indices_reg), " ptr [rip + .", stringify!($indices), "]"), $mode($reg) $dest, in($reg) $src);
    };
}
// Re-export the `#[macro_export]`-hoisted macros under their intended paths,
// `simd::swizzle!` / `simd::swizzle_indices!`.
pub use __swizzle as swizzle;
pub use __swizzle_indices as swizzle_indices;
/// Moves `v` into the low 64 bits of an SSE register (`vmovq`); the upper
/// 64 bits of the result are not set by this instruction's operands here.
/// NOTE(review): `vmovq` is the VEX encoding and needs AVX; the function is
/// not cfg-gated, so confirm call sites only exist on AVX-enabled x86 builds.
#[inline(always)]
pub fn load_u64_m128(v: u64) -> arch::__m128i {
unsafe {
let out: _;
std::arch::asm!("vmovq {}, {}", lateout(xmm_reg) out, in(reg) v);
out
}
}
/// Concatenates the low 64 bits of `a` (result bits 0..=63) with the low
/// 64 bits of `b` (result bits 64..=127).
#[inline(always)]
pub fn merge_low_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
unsafe {
// vpunpcklqdq out, a, b => out = [a.low_qword, b.low_qword]
let out: _;
std::arch::asm!("vpunpcklqdq {}, {}, {}", lateout(xmm_reg) out, in(xmm_reg) a, in(xmm_reg) b);
out
}
}
/// Concatenates two 128-bit vectors into one 256-bit vector.
///
/// The args are in little endian order (first arg is lowest order): `a`
/// becomes the low 128 bits of the result and `b` the high 128 bits
/// (`vinserti128` with immediate `0x1`).
// Added `#[inline(always)]` for consistency with every other asm helper in
// this module (`load_u64_m128`, `merge_low_hi_m128`, the extract fns).
#[inline(always)]
pub fn merge_m128_m256(a: arch::__m128i, b: arch::__m128i) -> arch::__m256i {
    unsafe {
        let out: _;
        std::arch::asm!("vinserti128 {}, {:y}, {}, 0x1", lateout(ymm_reg) out, in(ymm_reg) a, in(xmm_reg) b);
        out
    }
}
// Shared implementation for `extract_lo_bytes` / `extract_hi_bytes`: reduces
// a 256-bit input to a 128-bit result by applying `$op12` (with `$mask`) to
// each 128-bit half, then combining the halves with `$op3`.
// NOTE(review): uses `vextracti128`, which requires AVX2 — confirm call sites
// are feature-gated accordingly.
macro_rules! extract_lohi_bytes {
(($mask:expr, $op12:ident, $op3:ident), $in:ident) => {{
// SAFETY: `$mask` is a plain integer array of the same size as `__m128i`;
// transmuting POD-to-POD of equal size is sound.
const MASK: arch::__m128i = unsafe { std::mem::transmute($mask) };
unsafe {
let out: _;
std::arch::asm!(
//concat!("vmovdqa {mask}, xmmword ptr [rip + .", stringify!($mask), "]"),
"vextracti128 {inter}, {input:y}, 1",
concat!(stringify!($op12), " {inter}, {inter}, {mask}"),
concat!(stringify!($op12), " {output:x}, {input:x}, {mask}"),
concat!(stringify!($op3), " {output:x}, {output:x}, {inter}"),
mask = in(xmm_reg) MASK, input = in(ymm_reg) $in, output = lateout(xmm_reg) out, inter = out(xmm_reg) _);
out
}
}};
}
/// Extracts the low byte of each 16-bit lane of `v` into a 128-bit vector:
/// masks every lane with `0xff` (`vpand`), then packs the two halves
/// with `vpackuswb`.
#[inline(always)]
pub fn extract_lo_bytes(v: arch::__m256i) -> arch::__m128i {
extract_lohi_bytes!(([0xffu16; 8], vpand, vpackuswb), v)
}
/// Extracts the high byte of each 16-bit lane of `v` into a 128-bit vector:
/// `vpshufb` selects the odd-index bytes of each half (shuffle table
/// `[1,3,5,7,9,11,13,15, 0...]`), then `vpunpcklqdq` merges the halves.
#[inline(always)]
pub fn extract_hi_bytes(v: arch::__m256i) -> arch::__m128i {
extract_lohi_bytes!(([0x1u8, 0x3, 0x5, 0x7, 0x9, 0xb, 0xd, 0xf, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0], vpshufb, vpunpcklqdq), v)
}
/// SIMD vectors supporting a fused "AND then test for any set bit"
/// operation (maps to x86 `vptest`).
pub trait SimdTestAnd {
/// Returns true if the result of the bitwise AND of `self` and `mask` is not all zero.
fn test_and_non_zero(self, mask: Self) -> bool;
}
#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m128i {
    #[inline(always)]
    fn test_and_non_zero(self, mask: Self) -> bool {
        unsafe {
            let out: u8;
            // `vptest` sets ZF iff (self & mask) == 0; `setnz` materializes
            // !ZF directly as a 0/1 byte. This replaces the original branchy
            // jz/mov/jnz/label sequence, which was correct only because `mov`
            // preserves the flags set by `vptest`.
            std::arch::asm!(
                "vptest {a}, {b}",
                "setnz {out}",
                a = in(xmm_reg) self, b = in(xmm_reg) mask, out = out(reg_byte) out,
            );
            // SAFETY: `setnz` writes exactly 0 or 1 — the only valid bit
            // patterns for `bool`.
            std::mem::transmute(out)
        }
    }
}
#[cfg(target_feature = "avx")]
impl SimdTestAnd for arch::__m256i {
    #[inline(always)]
    fn test_and_non_zero(self, mask: Self) -> bool {
        unsafe {
            let out: u8;
            // `vptest` sets ZF iff (self & mask) == 0; `setnz` materializes
            // !ZF directly as a 0/1 byte. This replaces the original branchy
            // jz/mov/jnz/label sequence, which was correct only because `mov`
            // preserves the flags set by `vptest`.
            std::arch::asm!(
                "vptest {a}, {b}",
                "setnz {out}",
                a = in(ymm_reg) self, b = in(ymm_reg) mask, out = out(reg_byte) out,
            );
            // SAFETY: `setnz` writes exactly 0 or 1 — the only valid bit
            // patterns for `bool`.
            std::mem::transmute(out)
        }
    }
}