diff --git a/src/lib.rs b/src/lib.rs index 6c51469..08d38d0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,5 +12,5 @@ pub mod simd; -#[macro_use] pub mod util; +pub use util::{array_op, defer_impl, unroll}; diff --git a/src/simd.rs b/src/simd.rs index 2965b6e..b9b38b0 100644 --- a/src/simd.rs +++ b/src/simd.rs @@ -40,12 +40,15 @@ pub const DIGIT_BATCH_SIZE: usize = WIDE_BATCH_SIZE * 2; pub const GATHER_BATCH_SIZE: usize = DIGIT_BATCH_SIZE / 4; +#[cfg(feature = "trace-simd")] #[macro_export] macro_rules! __if_trace_simd { - ($( $tt:tt )*) => { - // disabled - //{ $( $tt )* } - }; + ($( $tt:tt )*) => { { $( $tt )* } }; +} +#[cfg(not(feature = "trace-simd"))] +#[macro_export] +macro_rules! __if_trace_simd { + ($( $tt:tt )*) => {}; } pub use __if_trace_simd as if_trace_simd; diff --git a/src/util.rs b/src/util.rs index 30b13f7..dee48b6 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,22 +1,44 @@ -use std::num::NonZeroUsize; +use core::mem::MaybeUninit; +use core::num::NonZeroUsize; + +#[doc(hidden)] +#[allow(non_snake_case)] +#[inline(always)] +pub const fn __array_op__uninit_array() -> [MaybeUninit; LEN] { + MaybeUninit::uninit_array() +} + +#[doc(hidden)] +#[allow(non_snake_case)] +#[inline(always)] +pub const unsafe fn __array_op__array_assume_init(array: [MaybeUninit; LEN]) -> [T; LEN] { + MaybeUninit::array_assume_init(array) +} /// Compile-time array operations. /// /// # Generation /// /// ```rust -/// const I_TIMES_2: [usize; 32] = array_op!(gen[32] |i| i * 2); +/// # use brisk::array_op; +/// const I_TIMES_2: [usize; 8] = array_op!(gen[8] |i| i * 2); +/// +/// assert_eq!(I_TIMES_2, [0, 2, 4, 6, 8, 10, 12, 14]); /// ``` /// /// # Mapping /// /// ```rust -/// const I_TIMES_2_PLUS_1: [usize; 32] = array_op!(map[32, I_TIMES_2] |i, x| x + 1); +/// # use brisk::array_op; +/// # const I_TIMES_2: [usize; 8] = array_op!(gen[8] |i| i * 2); +/// const I_TIMES_2_PLUS_1: [usize; 8] = array_op!(map[8, I_TIMES_2] |i, x| x + 1); +/// +/// assert_eq!(I_TIMES_2_PLUS_1, [1, 3, 5, 7, 9, 11, 13, 15]); /// ``` #[macro_export] -macro_rules! array_op { +macro_rules! __array_op { (gen[$len:expr] |$i:pat_param| $val:expr) => {{ - let mut out = ::core::mem::MaybeUninit::uninit_array::<$len>(); + let mut out = $crate::util::__array_op__uninit_array::<_, $len>(); let mut i = 0; while i < $len { out[i] = ::core::mem::MaybeUninit::new(match i { @@ -24,9 +46,8 @@ macro_rules! array_op { }); i += 1; } - #[allow(unused_unsafe)] unsafe { - ::core::mem::MaybeUninit::array_assume_init(out) + $crate::util::__array_op__array_assume_init(out) } }}; (map[$len:expr, $src:expr] |$i:pat_param, $s:pat_param| $val:expr) => {{ @@ -48,7 +69,7 @@ pub const fn cast_u8_u32(arr: [u8; N]) -> [u32; N] { } #[macro_export] -macro_rules! defer_impl { +macro_rules! __defer_impl { ( => $impl:ident; @@ -98,12 +119,10 @@ pub const unsafe fn cast(a: A) -> B { #[doc(hidden)] #[macro_export] macro_rules! __subst { - ([$( $ignore:tt )*], [$( $use:tt )*]) => { - $( $use )* - }; - } - -pub use __subst as subst; + ([$( $ignore:tt )*], [$( $use:tt )*]) => { + $( $use )* + }; +} #[inline(always)] pub const fn align_down_to(n: usize) -> usize { @@ -170,7 +189,7 @@ pub unsafe fn alloc_aligned(len: NonZeroUsize, align: usize) -> *mut u8 { } #[macro_export] -macro_rules! unroll { +macro_rules! __unroll { (let [$( $x:ident ),+] => |$y:pat_param| $expr:expr) => { $crate::unroll!(let [$( $x: ($x) ),+] => |$y| $expr); }; @@ -210,6 +229,11 @@ macro_rules! unroll { }; } +pub use __array_op as array_op; +pub use __defer_impl as defer_impl; +pub use __subst as subst; +pub use __unroll as unroll; + #[cfg(test)] mod test { use super::*;