diff --git a/src/lib.rs b/src/lib.rs
index 8b82af9..95fce60 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -27,6 +27,9 @@ use core::simd::*;
 use alloc::{boxed::Box, vec::Vec};

 use simd::SimdTestAnd as _;
+use simd::SimdBitwise as _;
+
+use util::array_op;

 // use the maximum batch size that would be supported by AVX-512
 //pub const SIMD_WIDTH: usize = 512;
@@ -43,74 +46,24 @@ const GATHER_BATCH_SIZE: usize = DIGIT_BATCH_SIZE / 4;
 macro_rules! if_trace_simd {
     ($( $tt:tt )*) => {
         // disabled
-        //$( $tt )*
+        //{ $( $tt )* }
     };
 }

 const VALIDATE: bool = true;

-#[inline]
-const fn alternating_mask<const N: usize>(first_bias: bool) -> [bool; N] {
-    let mut mask = [false; N];
-    let mut i = 0;
-    if first_bias {
-        while i < N / 2 {
-            mask[i * 2] = true;
-            i += 1;
-        }
-    } else {
-        while i < N / 2 {
-            mask[i * 2 + 1] = true;
-            i += 1;
-        }
-    }
-    mask
-}
-
-#[inline]
-const fn msb_lsb_indices<const N: usize>() -> [usize; N] {
-    if N % 2 != 0 {
-        panic!("Illegal N");
-    }
-
-    let mut indices = [0; N];
-    let mut i = 0;
-    while i < N / 2 {
-        indices[i] = i * 2;
-        indices[N / 2 + i] = i * 2 + 1;
-        i += 1;
-    }
-
-    indices
-}
-
 #[inline]
 const fn alternating_indices<const N: usize>(first_bias: bool) -> [usize; N] {
-    let mut indices = [0; N];
-    let mut i = 0;
     if first_bias {
-        while i < N {
-            indices[i] = i * 2;
-            i += 1;
-        }
+        array_op!(gen[N] |i| i * 2)
     } else {
-        while i < N {
-            indices[i] = i * 2 + 1;
-            i += 1;
-        }
+        array_op!(gen[N] |i| i * 2 + 1)
     }
-    indices
 }

 #[inline]
-const fn cast_usize_u8<const N: usize>(arr: [usize; N]) -> [u8; N] {
-    let mut arr1 = [0u8; N];
-    let mut i = 0;
-    while i < arr.len() {
-        arr1[i] = arr[i] as u8;
-        i += 1;
-    }
-    arr1
+const fn cast_u8_u32<const N: usize>(arr: [u8; N]) -> [u32; N] {
+    array_op!(map[N, arr] |_, v| v as u32)
 }

 const MSB_INDICES: [usize; DIGIT_BATCH_SIZE / 2] = alternating_indices(true);
@@ -121,9 +74,7 @@
 pub const INVALID_BIT: u8 = 0b1000_0000;
 pub const WIDE_INVALID_BIT: u16 = 0b1000_1000_0000_0000;

 const ASCII_DIGITS: [u8; 256] = {
-    let mut digits = [0u8; 256];
-    let mut i = u8::MIN;
-    while i < u8::MAX {
+    array_op!(gen[256] |i| {
         const DIGIT_MIN: u8 = '0' as u8;
         const DIGIT_MAX: u8 = '9' as u8;
         const LOWER_MIN: u8 = 'a' as u8;
@@ -131,27 +82,17 @@ const ASCII_DIGITS: [u8; 256] = {
         const UPPER_MIN: u8 = 'A' as u8;
         const UPPER_MAX: u8 = 'F' as u8;

-        digits[i as usize] = match i {
+        let i = i as u8;
+        match i {
             DIGIT_MIN..=DIGIT_MAX => i - DIGIT_MIN,
             LOWER_MIN..=LOWER_MAX => 10 + i - LOWER_MIN,
             UPPER_MIN..=UPPER_MAX => 10 + i - UPPER_MIN,
             _ => INVALID_BIT,
-        };
-
-        i += 1;
-    }
-    digits
+        }
+    })
 };

-const __ASCII_DIGITS_SIMD: [u32; 256] = {
-    let mut digits = [0u32; 256];
-    let mut i = u8::MIN;
-    while i < u8::MAX {
-        digits[i as usize] = ASCII_DIGITS[i as usize] as u32;
-        i += 1;
-    }
-    digits
-};
+const __ASCII_DIGITS_SIMD: [u32; 256] = cast_u8_u32(ASCII_DIGITS);

 const ASCII_DIGITS_SIMD: *const i32 = &__ASCII_DIGITS_SIMD as *const u32 as *const i32;
@@ -194,7 +135,7 @@ where
             &ASCII_DIGITS,
             Mask::splat(true),
             ascii.cast(),
-            simd::splat_0::<N>(),
+            simd::splat_0u8::<N>(),
         )
     }
 }
@@ -255,7 +196,7 @@ impl HexByteSimdDecoder for HexByteDecoderA {
     fn decode_simd(hi_los: [u8; DIGIT_BATCH_SIZE]) -> Option<Simd<u8, WIDE_BATCH_SIZE>> {
         let hex_digits = hex_digit_simd::<DIGIT_BATCH_SIZE>(Simd::from_array(hi_los));
         if ((hex_digits & simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT))
-            .simd_ne(simd::splat_0::<DIGIT_BATCH_SIZE>()))
+            .simd_ne(simd::splat_0u8::<DIGIT_BATCH_SIZE>()))
         .any()
         {
             return None;
         }
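The array_op! rewrite above and the while-loops it replaces compute the same 256-entry decode table. As a reference point, here is a standalone, stable-Rust sketch of the scheme (illustrative names, not the crate's API): the table maps ASCII to a digit value and marks everything else with bit 7, so validating a pair of digits is a single bitwise test.

const INVALID_BIT: u8 = 0b1000_0000;

const ASCII_DIGITS: [u8; 256] = {
    let mut digits = [INVALID_BIT; 256];
    let mut i = 0usize;
    while i < 256 {
        let c = i as u8;
        digits[i] = match c {
            b'0'..=b'9' => c - b'0',
            b'a'..=b'f' => 10 + c - b'a',
            b'A'..=b'F' => 10 + c - b'A',
            _ => INVALID_BIT,
        };
        i += 1;
    }
    digits
};

/// Decodes one byte from two ASCII hex digits, `None` on invalid input.
fn decode_byte(hi: u8, lo: u8) -> Option<u8> {
    let (hi, lo) = (ASCII_DIGITS[hi as usize], ASCII_DIGITS[lo as usize]);
    // one branch covers both digits because invalid entries share bit 7
    if (hi | lo) & INVALID_BIT != 0 {
        None
    } else {
        Some((hi << 4) | lo)
    }
}

fn main() {
    assert_eq!(decode_byte(b'a', b'5'), Some(0xa5));
    assert_eq!(decode_byte(b'g', b'0'), None);
}

Note that initializing the whole table up front also covers index 255, which the deleted `while i < u8::MAX` loop never visited; the `gen[256]` form of array_op! fixes that by construction.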
@@ -347,46 +288,131 @@ macro_rules! hex_digits_simd_inline {
         let c = simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, c, 4);
         let d = simd::arch::_mm256_i32gather_epi32(ASCII_DIGITS_SIMD, d, 4);

-        let a = Simd::<i32, GATHER_BATCH_SIZE>::from(a).cast::<u8>();
-        let b = Simd::<i32, GATHER_BATCH_SIZE>::from(b).cast::<u8>();
-        let c = Simd::<i32, GATHER_BATCH_SIZE>::from(c).cast::<u8>();
-        let d = Simd::<i32, GATHER_BATCH_SIZE>::from(d).cast::<u8>();
+        let a = Simd::<i32, GATHER_BATCH_SIZE>::from(a).cast::<u8>();
+        let b = Simd::<i32, GATHER_BATCH_SIZE>::from(b).cast::<u8>();
+        let c = Simd::<i32, GATHER_BATCH_SIZE>::from(c).cast::<u8>();
+        let d = Simd::<i32, GATHER_BATCH_SIZE>::from(d).cast::<u8>();

-        if_trace_simd! {
-            println!("{a:x?}, {b:x?}, {c:x?}, {d:x?}");
-        }
+        if_trace_simd! {
+            println!("{a:x?}, {b:x?}, {c:x?}, {d:x?}");
+        }

-        // load the 64-bit integers into registers
-        let a = simd::load_u64_m128(util::cast(a));
-        let b = simd::load_u64_m128(util::cast(b));
-        let c = simd::load_u64_m128(util::cast(c));
-        let d = simd::load_u64_m128(util::cast(d));
+        // load the 64-bit integers into registers
+        let a = simd::load_u64_m128(util::cast(a));
+        let b = simd::load_u64_m128(util::cast(b));
+        let c = simd::load_u64_m128(util::cast(c));
+        let d = simd::load_u64_m128(util::cast(d));

-        {
-            let a = Simd::<u8, 16>::from(a);
-            let b = Simd::<u8, 16>::from(b);
-            let c = Simd::<u8, 16>::from(c);
-            let d = Simd::<u8, 16>::from(d);
-            if_trace_simd! {
-                println!("a,b,c,d: {a:x?}, {b:x?}, {c:x?}, {d:x?}");
-            }
-        }
+        if_trace_simd! {
+            let a = Simd::<u8, 16>::from(a);
+            let b = Simd::<u8, 16>::from(b);
+            let c = Simd::<u8, 16>::from(c);
+            let d = Simd::<u8, 16>::from(d);
+            println!("a,b,c,d: {a:x?}, {b:x?}, {c:x?}, {d:x?}");
+        }

-        // copy the second 64-bit integer into the upper half of xmm0 (lower half is the first 64-bit integer)
-        let ab = simd::merge_low_hi_m128(a, b);
-        // copy the fourth 64-bit integer into the upper half of xmm2 (lower half is the third 64-bit integer)
-        let cd = simd::merge_low_hi_m128(c, d);
+        // copy the second 64-bit integer into the upper half of xmm0 (lower half is the first 64-bit integer)
+        let ab = simd::merge_lo_hi_m128(a, b);
+        // copy the fourth 64-bit integer into the upper half of xmm2 (lower half is the third 64-bit integer)
+        let cd = simd::merge_lo_hi_m128(c, d);

-        {
-            let ab = Simd::<u8, 16>::from(ab);
-            let cd = Simd::<u8, 16>::from(cd);
-            if_trace_simd! {
-                println!("ab,cd: {ab:x?}, {cd:x?}");
-            }
-        }
+        if_trace_simd! {
+            let ab = Simd::<u8, 16>::from(ab);
+            let cd = Simd::<u8, 16>::from(cd);
+            println!("ab,cd: {ab:x?}, {cd:x?}");
+        }

-        // merge the xmm0 and xmm1 (ymm1) registers into ymm0
-        simd::merge_m128_m256(ab, cd)
+        // merge the xmm0 and xmm1 (ymm1) registers into ymm0
+        let abcd = simd::merge_m128_m256(ab, cd);
+
+        if_trace_simd! {
+            let abcd: Simd<u8, DIGIT_BATCH_SIZE> = abcd.into();
+            println!("abcd: {abcd:x?}");
+        }
+
+        abcd
     }};
 }

+macro_rules! merge_hex_digits_into_bytes_inline {
+    ($hex_digits:ident) => {{
+        //let hex_digits = Simd::<u8, DIGIT_BATCH_SIZE>::from($hex_digits);
+        //let msb = simd_swizzle!(hex_digits, MSB_INDICES);
+        //let lsb = simd_swizzle!(hex_digits, LSB_INDICES);
+        let msb = simd::extract_lo_bytes($hex_digits);
+        let lsb = simd::extract_hi_bytes($hex_digits);
+
+        let msb1: simd::arch::__m128i;
+        unsafe { core::arch::asm!("vpsllw {dst}, {src}, 4", src = in(xmm_reg) msb, dst = lateout(xmm_reg) msb1) };
+        if_trace_simd! {
+            let msb1: Simd<u8, WIDE_BATCH_SIZE> = msb1.into();
+            println!("msb1: {msb1:x?}");
+        }
+        //let msb2: simd::arch::__m128i;
+        //unsafe { core::arch::asm!("vpand {dst}, {src}, 0x0000_0000_0000_0000_0000_0000_0000_f0f0", src = in(xmm_reg) msb1, dst = lateout(xmm_reg) msb2) };
+        let msb2 = msb1.and(Simd::from_array([0xf0f0u16; WIDE_BATCH_SIZE / 2]).into());
+        let b = msb2.or(lsb);

+        if_trace_simd! {
+            let msb: Simd<u8, WIDE_BATCH_SIZE> = msb.into();
+            let msb1: Simd<u8, WIDE_BATCH_SIZE> = msb1.into();
+            let msb2: Simd<u8, WIDE_BATCH_SIZE> = msb2.into();
+            let lsb: Simd<u8, WIDE_BATCH_SIZE> = lsb.into();
+            let b: Simd<u8, WIDE_BATCH_SIZE> = b.into();
+
+            println!("| Packed | Msb | <<4 | & | Lsb | Bytes | |");
+            Simd::<u8, DIGIT_BATCH_SIZE>::from($hex_digits)
+                .to_array()
+                .chunks(2)
+                .zip(msb.to_array())
+                .zip(msb1.to_array())
+                .zip(msb2.to_array())
+                .zip(lsb.to_array())
+                .zip(b.to_array())
+                .for_each(|(((((chunk, msb), msb1), msb2), lsb), b)| {
+                    println!(
+                        "| {chunk:02x?} | {msb:x?} | {msb1:x?} | {msb2:x?} | {lsb:x?} | {b:02x?} | {ok} |",
+                        chunk = (chunk[0] as u16) << 4 | (chunk[1] as u16),
+                        ok = if chunk[0] == msb && chunk[1] == lsb {
+                            '✓'
+                        } else {
+                            '✗'
+                        }
+                    );
+                });
+        }
+
+        b
+
+        /*let msb = simd::extract_lo_bytes($hex_digits);
+        let lsb = simd::extract_hi_bytes($hex_digits);
+
+        let msb: Simd<u8, WIDE_BATCH_SIZE> = msb.into();
+        let lsb: Simd<u8, WIDE_BATCH_SIZE> = lsb.into();
+
+        if_trace_simd! {
+            println!("msb: {msb:x?}");
+            println!("lsb: {lsb:x?}");
+            println!("| Packed | Msb | Lsb | |");
+            Simd::<u8, DIGIT_BATCH_SIZE>::from(hex_digits)
+                .to_array()
+                .chunks(2)
+                .zip(msb.to_array())
+                .zip(lsb.to_array())
+                .for_each(|((chunk, msb), lsb)| {
+                    println!(
+                        "| {chunk:02x?} | {msb:x?} | {lsb:x?} | {ok} |",
+                        chunk = (chunk[0] as u16) << 4 | (chunk[1] as u16),
+                        ok = if chunk[0] == msb && chunk[1] == lsb {
+                            '✓'
+                        } else {
+                            '✗'
+                        }
+                    );
+                });
+        }
+        //simd::merge_lo_hi_m128(msb, lsb)
+        simd::arch::__m128i::from((msb << simd::splat_n::<WIDE_BATCH_SIZE>(4)) | lsb)*/
+    }};
+}
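Lane for lane, merge_hex_digits_into_bytes_inline! performs the data movement below. This scalar model (illustrative only, assuming DIGIT_BATCH_SIZE = 32 and WIDE_BATCH_SIZE = 16) shows why the vpsllw-by-4, vpand-0xf0f0, vpor sequence is equivalent to (hi << 4) | lo on each output byte.

// Scalar model: digits arrive interleaved (hi, lo, hi, lo, ...) and each
// pair packs into one output byte, exactly what the xmm code does 16-wide.
fn merge_hex_digits_into_bytes(digits: [u8; 32]) -> [u8; 16] {
    let mut out = [0u8; 16];
    let mut i = 0;
    while i < 16 {
        let hi = digits[i * 2]; // per-lane view of extract_lo_bytes
        let lo = digits[i * 2 + 1]; // per-lane view of extract_hi_bytes
        // vpsllw by 4, then vpand 0xf0f0, then vpor, collapsed to one lane:
        out[i] = (hi << 4) | lo;
        i += 1;
    }
    out
}

fn main() {
    let mut digits = [0u8; 32];
    digits[0] = 0xa; // "a5..." decodes to 0xa5 in the first output byte
    digits[1] = 0x5;
    assert_eq!(merge_hex_digits_into_bytes(digits)[0], 0xa5);
}

The word-sized shift is safe here because every digit is at most 0x0f, so shifting left by 4 never carries across a byte boundary within the 16-bit lanes.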
@@ -406,16 +432,8 @@ impl HexByteSimdDecoder for HexByteDecoderB {
         if hex_digits.test_and_non_zero(simd::splat_n::<DIGIT_BATCH_SIZE>(INVALID_BIT).into()) {
             return None;
         }
-        let hex_digits = Simd::<u8, DIGIT_BATCH_SIZE>::from(hex_digits);
-        let msb = simd_swizzle!(hex_digits, MSB_INDICES);
-        let lsb = simd_swizzle!(hex_digits, LSB_INDICES);
-        let mut v = Simd::from_array([0u8; WIDE_BATCH_SIZE]);
-        for (i, v) in v.as_mut_array().iter_mut().enumerate() {
-            let hi = unsafe { *msb.as_array().get_unchecked(i) };
-            let lo = unsafe { *lsb.as_array().get_unchecked(i) };
-            *v = (hi << 4) | lo;
-        }
-        Some(v)
+
+        Some(merge_hex_digits_into_bytes_inline!(hex_digits).into())
     }
 }
@@ -424,94 +442,16 @@ pub type HBD = HexByteDecoderB;

 pub mod conv {
     use core::simd::{LaneCount, Simd, SupportedLaneCount};

-    /*trait Size {
-        const N: usize;
-    }
-
-    macro_rules! size_impl {
-        ($ident:ident($size:expr)) => {
-            struct $ident;
-
-            impl Size for $ident {
-                const N: usize = $size;
-            }
-        };
-        ($ident:ident<$size:ty>) => {
-            size_impl!($ident(core::mem::size_of::<$size>()));
-        };
-    }
-
-    struct SizeMul<T: Size, const N: usize>(core::marker::PhantomData<T>);
-
-    impl<T: Size, const N: usize> Size for SizeMul<T, N> {
-        const N: usize = T::N * N;
-    }
-
-    size_impl!(SizeU8<u8>);
-    size_impl!(SizeU16<u16>);
-    size_impl!(SizeU32<u32>);
-    size_impl!(SizeU64<u64>);
-
-    trait SizeOf {
-        type Size: Size;
-
-        //const SIZE: usize;
-    }
-
-    //impl<T> SizeOf for T {
-    //    const SIZE: usize = core::mem::size_of::<T>();
-    //}
-
-    macro_rules! size_of_impl {
-        ($type:ty = $size:ident) => {
-            impl SizeOf for $type {
-                type Size = $size;
-            }
-        };
-    }
-
-    size_of_impl!(u8 = SizeU8);
-    size_of_impl!([u8; 2] = SizeU16);
-    size_of_impl!(u16 = SizeU16);
-    size_of_impl!(u32 = SizeU32);
-    size_of_impl!(u64 = SizeU64);*/
-
-    #[allow(non_camel_case_types, non_snake_case)]
-    union u8_u16<const N_u16: usize>
-    where
-        [u8; N_u16 * 2]:,
-    {
-        u8: [u8; N_u16 * 2],
-        u16: [u16; N_u16],
-    }
-
-    #[allow(non_camel_case_types, non_snake_case)]
-    union u8x2_u8<const N_u8x2: usize>
-    where
-        [u8; N_u8x2 * 2]:,
-    {
-        u8x2: [[u8; 2]; N_u8x2],
-        u8: [u8; N_u8x2 * 2],
-    }
-
-    #[allow(non_camel_case_types, non_snake_case)]
-    union SimdU8_SimdU16<const N_U16: usize>
-    where
-        LaneCount<{ N_U16 * 2 }>: SupportedLaneCount,
-        LaneCount<N_U16>: SupportedLaneCount,
-    {
-        SimdU8: Simd<u8, { N_U16 * 2 }>,
-        SimdU16: Simd<u16, N_U16>,
-    }
+    use crate::util;

     #[inline(always)]
     pub const fn u8_to_u16<const N_OUT: usize>(a: [u8; N_OUT * 2]) -> [u16; N_OUT] {
-        unsafe { u8_u16 { u8: a }.u16 }
+        unsafe { util::cast(a) }
     }

     #[inline(always)]
     pub const fn u8x2_to_u8<const N_IN: usize>(a: [[u8; 2]; N_IN]) -> [u8; N_IN * 2] {
-        unsafe { u8x2_u8 { u8x2: a }.u8 }
+        unsafe { util::cast(a) }
     }

     #[inline(always)]
@@ -522,7 +462,7 @@ pub mod conv {
         LaneCount<{ N_OUT * 2 }>: SupportedLaneCount,
         LaneCount<N_OUT>: SupportedLaneCount,
     {
-        unsafe { SimdU8_SimdU16 { SimdU8: a }.SimdU16 }
+        unsafe { util::cast(a) }
     }
 }
@@ -598,7 +538,7 @@ fn decode_hex_bytes_unchecked<const VECTORED: bool>(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bool {
     if VECTORED {
         use simd::arch;

-        let mut bad: arch::__m256i = simd::splat_0().into();
+        let mut bad: arch::__m256i = simd::splat_0u8().into();
         let mut i = 0;
         while i < util::align_down_to::<DIGIT_BATCH_SIZE>(ascii.len()) {
             let hex_digits = unsafe {
@@ -611,43 +551,11 @@ fn decode_hex_bytes_unchecked<const VECTORED: bool>(ascii: &[u8], bytes: &mut [MaybeUninit<u8>]) -> bool {
                 core::arch::asm!("vpor {bad}, {digits}, {bad}", bad = inout(ymm_reg) bad, digits = in(ymm_reg) hex_digits, options(pure, nomem, preserves_flags, nostack));
             }
         }
-            let hex_digits: Simd<u8, DIGIT_BATCH_SIZE> = hex_digits.into();
-            if_trace_simd! {
-                println!("hex_digits: {hex_digits:x?}");
-            }
-            let hex_digits: arch::__m256i = hex_digits.into();
-            let msb = simd::extract_lo_bytes(hex_digits);
-            let lsb = simd::extract_hi_bytes(hex_digits);
-
-            let msb: Simd<u8, WIDE_BATCH_SIZE> = msb.into();
-            let lsb: Simd<u8, WIDE_BATCH_SIZE> = lsb.into();
-
-            if_trace_simd! {
-                println!("msb: {msb:x?}");
-                println!("lsb: {lsb:x?}");
-                println!("| Packed | Msb | Lsb | |");
-                Simd::<u8, DIGIT_BATCH_SIZE>::from(hex_digits)
-                    .to_array()
-                    .chunks(2)
-                    .zip(msb.to_array())
-                    .zip(lsb.to_array())
-                    .for_each(|((chunk, msb), lsb)| {
-                        println!(
-                            "| {chunk:02x?} | {msb:x?} | {lsb:x?} | {ok} |",
-                            chunk = (chunk[0] as u16) << 4 | (chunk[1] as u16),
-                            ok = if chunk[0] == msb && chunk[1] == lsb {
-                                '✓'
-                            } else {
-                                '✗'
-                            }
-                        );
-                    });
-            }
-            let buf = (msb << simd::splat_n::<WIDE_BATCH_SIZE>(4)) | lsb;
+            let buf = merge_hex_digits_into_bytes_inline!(hex_digits);

             //let buf: arch::__m128i = unsafe { util::cast(buf) };
-            let buf: arch::__m128i = buf.into();
+            //let buf: arch::__m128i = buf.into();

             unsafe {
                 // vmovaps xmm0, xmmword ptr [rsi]
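The vectored loop above defers validation: every gathered batch is vpor-ed into the `bad` accumulator and INVALID_BIT is tested once after the loop, keeping branches out of the hot path. A scalar sketch of the pattern (hypothetical helper, not the crate's signature):

const INVALID_BIT: u8 = 0b1000_0000;

// `decode_all` is an illustrative name; the real code works on ymm registers.
fn decode_all(batches: &[[u8; 32]]) -> bool {
    let mut bad = 0u8; // stands in for the `bad` ymm accumulator
    for batch in batches {
        for &digit in batch {
            bad |= digit; // the `vpor {bad}, {digits}, {bad}` step
        }
        // ...merge this batch's digits into bytes and store them here...
    }
    // a single test after the loop catches invalid input from any batch
    bad & INVALID_BIT == 0
}

fn main() {
    assert!(decode_all(&[[0x0f; 32]]));
    assert!(!decode_all(&[[INVALID_BIT; 32]]));
}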
diff --git a/src/simd.rs b/src/simd.rs
index 00ac24e..21acfbc 100644
--- a/src/simd.rs
+++ b/src/simd.rs
@@ -1,4 +1,4 @@
-use core::simd::{LaneCount, Simd, SupportedLaneCount};
+use core::simd::{LaneCount, Simd, SupportedLaneCount, SimdElement};

 use crate::util::cast;

@@ -19,15 +19,28 @@ pub use core::arch::x86 as arch;
 #[cfg(target_arch = "x86_64")]
 pub use core::arch::x86_64 as arch;

+pub trait IsSimd {
+    type Lane;
+
+    const LANES: usize;
+}
+
+impl<T, const LANES: usize> IsSimd for Simd<T, LANES> where LaneCount<LANES>: SupportedLaneCount, T: SimdElement {
+    type Lane = T;
+
+    const LANES: usize = LANES;
+}
+
 macro_rules! specialized {
     ($(
-        $vis:vis fn $name:ident<$LANES:ident$(, $( $generics:tt )+)?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
+        $vis:vis fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
             $(
                 $width:pat_param $( if $cfg:meta )? => $impl:expr
             ),+
             $(,)?
         }
     )+) => {$(
+        #[allow(dead_code)]
         #[inline(always)]
         $vis fn $name<const $LANES: usize$(, $( $generics )+)?>($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
             // abusing const generics to specialize without the unsoundness of real specialization!
             match $LANES {
                 $(
                     $( #[cfg( $cfg )] )?
                     $width => $impl
                 ),+
             }
         }
     )+};
+    ($trait:ident for $ty:ty;
+        $(
+            fn $name:ident<$LANES:ident$(, [$( $generics:tt )+])?>($( $argn:ident: $argt:ty ),*) -> $rt:ty $(where [ $( $where:tt )* ])? {
+                $(
+                    $width:pat_param $( if $cfg:meta )? => $impl:expr
+                ),+
+                $(,)?
+            }
+        )+
+    ) => {
+        impl $trait for $ty {$(
+            #[inline(always)]
+            fn $name$(<$( $generics )+>)?($( $argn: $argt ),*) -> $rt $( where $( $where )* )? {
+                // abusing const generics to specialize without the unsoundness of real specialization!
+                match $LANES {
+                    $(
+                        $( #[cfg( $cfg )] )?
+                        $width => $impl
+                    ),+
+                }
+            }
+        )+}
+    };
 }

 macro_rules! set1 {
-    ($arch:ident, $vec:ident, $reg:ident, $n:ident) => {{
+    ($arch:ident, $inst:ident, $vec:ident, $reg:ident, $n:ident) => {{
         let out: core::arch::$arch::$vec;
-        core::arch::asm!("vpbroadcastb {}, {}", lateout($reg) out, in(xmm_reg) cast::<_, core::arch::$arch::__m128i>($n), options(pure, nomem, preserves_flags, nostack));
+        core::arch::asm!(concat!(stringify!($inst), " {}, {}"), lateout($reg) out, in(xmm_reg) cast::<_, core::arch::$arch::__m128i>($n), options(pure, nomem, preserves_flags, nostack));
         out
     }};
 }

@@ -54,10 +90,10 @@ specialized! {
     pub fn splat_n<W>(n: u8) -> Simd<u8, W> where [
         LaneCount<W>: SupportedLaneCount,
     ] {
-        W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, __m128i, xmm_reg, n)) },
-        W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, __m128i, xmm_reg, n)) },
-        W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, __m256i, ymm_reg, n)) },
-        W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, __m256i, ymm_reg, n)) },
+        W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastb, __m128i, xmm_reg, n)) },
+        W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastb, __m128i, xmm_reg, n)) },
+        W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastb, __m256i, ymm_reg, n)) },
+        W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastb, __m256i, ymm_reg, n)) },

         // these are *terrible*. They compile to a bunch of MOVs and SETs
         W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi8(n as i8)) },
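The specialized! macro dispatches on the lane-count const generic: after monomorphization the match collapses to a single arm, and the #[cfg] guards prune arms the target cannot compile. A reduced, hypothetical model of the trick:

// Hypothetical widths; the real code matches W_128/W_256/W_512 lane counts.
const W_128: usize = 16;
const W_256: usize = 32;

#[inline(always)]
fn splat_strategy<const LANES: usize>() -> &'static str {
    // LANES is a compile-time constant, so each monomorphization keeps
    // exactly one live arm; the macro additionally wraps arms in #[cfg(...)].
    match LANES {
        W_128 => "xmm splat (e.g. vpbroadcastb)",
        W_256 => "ymm splat (e.g. vpbroadcastb)",
        _ => "portable Simd::splat fallback",
    }
}

fn main() {
    assert_eq!(splat_strategy::<32>(), "ymm splat (e.g. vpbroadcastb)");
}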
@@ -71,7 +107,28 @@ specialized! {
         _ => Simd::splat(n),
     }

-    pub fn splat_0<W>() -> Simd<u8, W> where [
+    // TODO: special case https://www.felixcloutier.com/x86/vpbroadcastb:vpbroadcastw:vpbroadcastd:vpbroadcastq
+    pub fn splat_u16<W>(n: u16) -> Simd<u16, W> where [
+        LaneCount<W>: SupportedLaneCount,
+    ] {
+        W_128 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastw, __m128i, xmm_reg, n)) },
+        W_128 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastw, __m128i, xmm_reg, n)) },
+        W_256 if all(target_arch = "x86_64", target_feature = "avx2") => unsafe { cast(set1!(x86_64, vpbroadcastw, __m256i, ymm_reg, n)) },
+        W_256 if all(target_arch = "x86", target_feature = "avx2") => unsafe { cast(set1!(x86, vpbroadcastw, __m256i, ymm_reg, n)) },
+
+        // these are *terrible*. They compile to a bunch of MOVs and SETs
+        W_128 if all(target_arch = "x86_64", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm_set1_epi16(n as i16)) },
+        W_128 if all(target_arch = "x86", target_feature = "sse2", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm_set1_epi16(n as i16)) },
+        W_256 if all(target_arch = "x86_64", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86_64::_mm256_set1_epi16(n as i16)) },
+        W_256 if all(target_arch = "x86", target_feature = "avx", not(target_feature = "avx2")) => unsafe { cast(core::arch::x86::_mm256_set1_epi16(n as i16)) },
+
+        // I can't really test these, but they're documented as doing either a broadcast or the terrible approach mentioned above.
+        W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_set1_epi16(n as i16)) },
+        W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_set1_epi16(n as i16)) },
+        _ => Simd::splat(n),
+    }
+
+    pub fn splat_0u8<W>() -> Simd<u8, W> where [
         LaneCount<W>: SupportedLaneCount,
     ] {
         // these are fine, they are supposed to XOR themselves to zero out.
@@ -81,7 +138,20 @@ specialized! {
         W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
         W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
         W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
-        _ => splat_n(0),
+        _ => unsafe { crate::util::cast(splat_n(0)) },
+    }
+
+    pub fn splat_0u16<W>() -> Simd<u16, W> where [
+        LaneCount<W>: SupportedLaneCount,
+    ] {
+        // these are fine, they are supposed to XOR themselves to zero out.
+        W_128 if all(target_arch = "x86_64", target_feature = "sse2") => unsafe { cast(core::arch::x86_64::_mm_setzero_si128()) },
+        W_128 if all(target_arch = "x86", target_feature = "sse2") => unsafe { cast(core::arch::x86::_mm_setzero_si128()) },
+        W_256 if all(target_arch = "x86_64", target_feature = "avx") => unsafe { cast(core::arch::x86_64::_mm256_setzero_si256()) },
+        W_256 if all(target_arch = "x86", target_feature = "avx") => unsafe { cast(core::arch::x86::_mm256_setzero_si256()) },
+        W_512 if all(target_arch = "x86_64", target_feature = "avx512f") => unsafe { cast(core::arch::x86_64::_mm512_setzero_si512()) },
+        W_512 if all(target_arch = "x86", target_feature = "avx512f") => unsafe { cast(core::arch::x86::_mm512_setzero_si512()) },
+        _ => unsafe { crate::util::cast(splat_n(0)) },
+    }
 }

@@ -156,8 +226,9 @@ pub fn load_u64_m128(v: u64) -> arch::__m128i {
     }
 }

+/// Merges the low halves of `a` and `b` into a single register like `ab`.
 #[inline(always)]
-pub fn merge_low_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
+pub fn merge_lo_hi_m128(a: arch::__m128i, b: arch::__m128i) -> arch::__m128i {
     unsafe {
         // xmm0 = xmm1[0],xmm0[0]
         let out: _;
@@ -217,6 +288,14 @@ pub trait SimdTestAnd {
     fn test_and_non_zero(self, mask: Self) -> bool;
 }

+pub trait SimdBitwise {
+    /// Returns the bitwise OR of `self` and `rhs`.
+    fn or(self, rhs: Self) -> Self;
+
+    /// Returns the bitwise AND of `self` and `rhs`.
+    fn and(self, rhs: Self) -> Self;
+}
+
 #[cfg(target_feature = "avx")]
 impl SimdTestAnd for arch::__m128i {
     #[inline(always)]
@@ -240,3 +319,63 @@ impl SimdTestAnd for arch::__m256i {
         }
     }
 }
+
+const USE_BITWISE_INTRINSICS: bool = true;
+
+#[cfg(target_feature = "avx")]
+impl SimdBitwise for arch::__m128i {
+    #[inline(always)]
+    fn or(self, rhs: Self) -> Self {
+        unsafe {
+            if USE_BITWISE_INTRINSICS {
+                arch::_mm_or_si128(self, rhs)
+            } else {
+                let out: _;
+                core::arch::asm!("vpor {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = out(xmm_reg) out, options(pure, nomem, preserves_flags, nostack));
+                out
+            }
+        }
+    }
+
+    #[inline(always)]
+    fn and(self, rhs: Self) -> Self {
+        unsafe {
+            if USE_BITWISE_INTRINSICS {
+                arch::_mm_and_si128(self, rhs)
+            } else {
+                let out: _;
+                core::arch::asm!("vpand {out}, {a}, {b}", a = in(xmm_reg) self, b = in(xmm_reg) rhs, out = out(xmm_reg) out, options(pure, nomem, preserves_flags, nostack));
+                out
+            }
+        }
+    }
+}
+
+#[cfg(target_feature = "avx2")]
+impl SimdBitwise for arch::__m256i {
+    #[inline(always)]
+    fn or(self, rhs: Self) -> Self {
+        unsafe {
+            if USE_BITWISE_INTRINSICS {
+                arch::_mm256_or_si256(self, rhs)
+            } else {
+                let out: _;
+                core::arch::asm!("vpor {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = out(ymm_reg) out, options(pure, nomem, preserves_flags, nostack));
+                out
+            }
+        }
+    }
+
+    #[inline(always)]
+    fn and(self, rhs: Self) -> Self {
+        unsafe {
+            if USE_BITWISE_INTRINSICS {
+                arch::_mm256_and_si256(self, rhs)
+            } else {
+                let out: _;
+                core::arch::asm!("vpand {out}, {a}, {b}", a = in(ymm_reg) self, b = in(ymm_reg) rhs, out = out(ymm_reg) out, options(pure, nomem, preserves_flags, nostack));
+                out
+            }
+        }
+    }
+}
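SimdBitwise puts an intrinsic and a hand-written asm body behind one call site, switched by USE_BITWISE_INTRINSICS. A self-contained reduction of the intrinsic path (a sketch only, just `or` on __m128i; SSE2 is baseline on x86-64):

#[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
mod sketch {
    use core::arch::x86_64::{__m128i, _mm_or_si128, _mm_set1_epi8};

    pub trait SimdOr {
        fn or(self, rhs: Self) -> Self;
    }

    impl SimdOr for __m128i {
        #[inline(always)]
        fn or(self, rhs: Self) -> Self {
            // the USE_BITWISE_INTRINSICS = true branch of the diff
            unsafe { _mm_or_si128(self, rhs) }
        }
    }

    pub fn demo() -> __m128i {
        let a = unsafe { _mm_set1_epi8(0x0f) };
        let b = unsafe { _mm_set1_epi8(0x70) };
        a.or(b) // compiles to a single por/vpor
    }
}

fn main() {
    #[cfg(all(target_arch = "x86_64", target_feature = "sse2"))]
    let _ = sketch::demo();
}

Keeping the intrinsic as the default is the safer choice here: inline asm is opaque to the optimizer, while _mm_or_si128 can still be folded or reordered.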
diff --git a/src/util.rs b/src/util.rs
index c010fa5..dc97759 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -1,3 +1,22 @@
+#[doc(hidden)]
+#[macro_export]
+macro_rules! __array_op {
+    (gen[$len:expr] |$i:pat_param| $val:expr) => {{
+        let mut out = core::mem::MaybeUninit::uninit_array();
+        let mut i = 0;
+        while i < $len {
+            out[i] = core::mem::MaybeUninit::new(match i { $i => $val });
+            i += 1;
+        }
+        unsafe { core::mem::MaybeUninit::array_assume_init(out) }
+    }};
+    (map[$len:expr, $src:expr] |$i:pat_param, $s:pat_param| $val:expr) => {{
+        $crate::util::array_op!(gen[$len] |i| match i { $i => match $src[i] { $s => $val } })
+    }};
+}
+
+pub use __array_op as array_op;
+
 #[doc(hidden)]
 #[macro_export]
 macro_rules! __defer_impl {
@@ -19,10 +38,10 @@ pub use __defer_impl as defer_impl;

 #[inline(always)]
 #[cold]
-pub fn cold() {}
+pub const fn cold() {}

 #[inline]
-pub fn likely(b: bool) -> bool {
+pub const fn likely(b: bool) -> bool {
     if !b {
         cold()
     }
@@ -30,7 +49,7 @@ pub fn likely(b: bool) -> bool {
 }

 #[inline]
-pub fn unlikely(b: bool) -> bool {
+pub const fn unlikely(b: bool) -> bool {
     if b {
         cold()
     }
@@ -40,7 +59,7 @@ pub fn unlikely(b: bool) -> bool {

 /// Like transmute, but implemented via a union so that we can use it in situations where
 /// transmute's "safety" restrictions are too strict and uninformed (i.e. we can prove it is safe).
 #[inline(always)]
-pub unsafe fn cast<A, B>(a: A) -> B {
+pub const unsafe fn cast<A, B>(a: A) -> B {
     union Cast<A, B> {
         a: core::mem::ManuallyDrop<A>,
         b: core::mem::ManuallyDrop<B>,
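array_op!(gen ...) is the MaybeUninit-based generalization of the usual const-fn initializer idiom (for loops and iterators are unavailable in const contexts, so a while loop does the work). For one concrete case, the same idea reads like this on stable Rust (illustrative, mirroring alternating_indices):

// Const-initialize a table from an index function using a while loop;
// the nightly macro avoids the dummy initial value via MaybeUninit.
const fn gen_table<const N: usize>() -> [usize; N] {
    let mut out = [0usize; N];
    let mut i = 0;
    while i < N {
        out[i] = i * 2; // the |i| i * 2 body from alternating_indices(true)
        i += 1;
    }
    out
}

const EVENS: [usize; 4] = gen_table::<4>();

fn main() {
    assert_eq!(EVENS, [0, 2, 4, 6]);
}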