// brisk/src/util.rs (256 lines, 7.5 KiB, Rust)

use std::num::NonZeroUsize;
/// Compile-time array construction, usable in `const` items and `const fn`s.
///
/// # Generation
///
/// `array_op!(gen[LEN] |i| expr)` builds a `[T; LEN]` whose element at index `i` (a `usize`
/// running from `0` to `LEN - 1`) is `expr`.
///
/// ```ignore
/// const I_TIMES_2: [usize; 32] = array_op!(gen[32] |i| i * 2);
/// ```
///
/// # Mapping
///
/// `array_op!(map[LEN, SRC] |i, x| expr)` builds a `[T; LEN]` whose element at index `i` is
/// `expr`, with `x` bound to `SRC[i]`.
///
/// ```ignore
/// const I_TIMES_2: [usize; 32] = array_op!(gen[32] |i| i * 2);
/// const I_TIMES_2_PLUS_1: [usize; 32] = array_op!(map[32, I_TIMES_2] |_, x| x + 1);
/// ```
///
/// # Notes
///
/// - The `|...|` parameters are patterns matched by a single-arm `match`, so they must be
///   irrefutable (e.g. an identifier or `_`).
/// - In `map`, `SRC` is pasted into the per-element body and is therefore evaluated once per
///   element; pass a constant or a place, not an expensive expression. `SRC[i]` is matched by
///   value, so non-`Copy` elements need a `ref` pattern.
/// - The examples are `ignore`d because the expansion calls the unstable
///   `MaybeUninit::uninit_array` and `MaybeUninit::array_assume_init`: the crate that expands
///   the macro must enable `maybe_uninit_uninit_array` and `maybe_uninit_array_assume_init`,
///   and a doctest would additionally need to import the macro from the crate root.
/// - NOTE(review): newer nightlies appear to have removed `MaybeUninit::uninit_array` in favour
///   of `[const { MaybeUninit::uninit() }; N]`; confirm against the pinned toolchain.
#[macro_export]
macro_rules! array_op {
    (gen[$len:expr] |$i:pat_param| $val:expr) => {{
        let mut out = ::core::mem::MaybeUninit::uninit_array::<$len>();
        let mut i = 0;
        // A `while` loop rather than `for`: iterator-based loops are not allowed in `const fn`.
        while i < $len {
            // The user's pattern is bound to the current index via a single-arm `match`.
            out[i] = ::core::mem::MaybeUninit::new(match i {
                $i => $val,
            });
            i += 1;
        }
        // Presumably here to avoid `unused_unsafe` when expanded inside an existing `unsafe`
        // context.
        #[allow(unused_unsafe)]
        unsafe {
            // SAFETY: the loop above wrote every index in `0..$len`.
            ::core::mem::MaybeUninit::array_assume_init(out)
        }
    }};
    (map[$len:expr, $src:expr] |$i:pat_param, $s:pat_param| $val:expr) => {{
        // Delegates to the `gen` arm. `$src` is pasted into the per-element body, so it is
        // re-evaluated once for every index.
        $crate::array_op!(
            gen[$len]
            | i
            | match i {
                $i => match $src[i] {
                    $s => $val,
                },
            }
        )
    }};
}
/// Widens every byte of `arr` to a `u32`, preserving element order.
///
/// Written as a plain `while` loop so it stays usable in `const` contexts.
#[inline]
pub const fn cast_u8_u32<const N: usize>(arr: [u8; N]) -> [u32; N] {
    let mut widened = [0u32; N];
    let mut idx = 0;
    while idx < N {
        widened[idx] = arr[idx] as u32;
        idx += 1;
    }
    widened
}
/// Generates forwarding functions that delegate to the same-named associated functions of
/// another type.
///
/// ```ignore
/// defer_impl! {
///     => Other;
///     fn add(a: u32, b: u32) -> u32;
///     fn zero() -> u32;
/// }
/// // expands to:
/// // #[inline(always)] fn add(a: u32, b: u32) -> u32 { <Other>::add(a, b) }
/// // #[inline(always)] fn zero() -> u32 { <Other>::zero() }
/// ```
///
/// Each signature must be written as `fn name(param: Type, ...) -> Ret;` with an explicit
/// return type, plain `name: Type` parameters, and no trailing comma or generics. The generated
/// functions carry no visibility modifier; presumably the macro is meant to be expanded inside
/// a trait `impl` block to delegate to another implementation (confirm with callers).
#[macro_export]
macro_rules! defer_impl {
    (
        => $impl:ident;
        $( fn $name:ident($( $pname:ident: $pty:ty ),*) -> $rty:ty; )*
    ) => {
        $(
            #[inline(always)]
            fn $name($( $pname: $pty ),*) -> $rty {
                // Every argument is passed through unchanged, in declaration order.
                <$impl>::$name($( $pname ),*)
            }
        )*
    };
}
/// No-op marked `#[cold]`: calling it on a branch tells the optimizer that the branch is
/// unlikely to be taken.
///
/// NOTE(review): it is also `#[inline(always)]`; confirm the cold hint still influences code
/// layout once the empty body has been inlined.
#[inline(always)]
#[cold]
pub const fn cold() {}
/// Returns `b` unchanged while hinting to the optimizer that it is expected to be `true`.
///
/// Thin wrapper over the unstable `core::intrinsics::likely`, so it only builds on a nightly
/// toolchain with the intrinsics feature enabled for this crate.
#[inline(always)]
pub const fn likely(b: bool) -> bool {
    core::intrinsics::likely(b)
}
/// Returns `b` unchanged while hinting to the optimizer that it is expected to be `false`.
///
/// Thin wrapper over the unstable `core::intrinsics::unlikely`, so it only builds on a nightly
/// toolchain with the intrinsics feature enabled for this crate.
#[inline(always)]
pub const fn unlikely(b: bool) -> bool {
    core::intrinsics::unlikely(b)
}
/// Reinterprets the bytes of `a` as a value of type `B`, like [`core::mem::transmute`], but
/// implemented with a `union` so it can be used where transmute's compile-time size check is
/// too strict or uninformed (for example when the sizes depend on generic parameters).
///
/// `a` is moved in and never dropped; ownership of its contents passes to the returned `B`.
///
/// # Safety
///
/// The caller must guarantee that the leading `size_of::<B>()` bytes of `a` form a valid `B`.
/// If `B` is larger than `A`, the trailing bytes of the result are not written by `a` and are
/// therefore uninitialized, which is undefined behavior for almost every `B`.
#[inline(always)]
pub const unsafe fn cast<A, B>(a: A) -> B {
    // `repr(C)` guarantees that every field of the union starts at offset 0, so reading `b`
    // observes exactly the bytes written through `a`, instead of relying on the layout of the
    // default representation.
    #[repr(C)]
    union Cast<A, B> {
        a: core::mem::ManuallyDrop<A>,
        b: core::mem::ManuallyDrop<B>,
    }
    // `ManuallyDrop` is required for non-`Copy` union fields and ensures the moved-in `a` is
    // never dropped; its contents now belong to the returned `B`.
    core::mem::ManuallyDrop::into_inner(
        Cast {
            a: core::mem::ManuallyDrop::new(a),
        }
        .b,
    )
}
/// Expands to the second bracketed token group and discards the first:
/// `subst!([ignored tokens], [output tokens])` becomes `output tokens`.
///
/// Presumably used inside another macro's `$( ... )*` repetition to emit a fixed token sequence
/// once per repeated metavariable, by naming that metavariable in the ignored group (confirm
/// against callers).
#[doc(hidden)]
#[macro_export]
macro_rules! __subst {
    ([$( $ignore:tt )*], [$( $use:tt )*]) => {
        $( $use )*
    };
}
/// Path-importable name for the `#[macro_export]`ed `__subst` macro.
pub use __subst as subst;
/// Rounds `n` down to the nearest multiple of `N`.
///
/// `N == 0` means "no alignment" and returns `n` unchanged. Any `N` is supported, not only
/// powers of two; because `N` is a compile-time constant, `n % N` for a power-of-two `N`
/// typically compiles to a simple mask.
#[inline(always)]
pub const fn align_down_to<const N: usize>(n: usize) -> usize {
    if N == 0 {
        n
    } else {
        n - n % N
    }
}
/// Rounds `n` up to the nearest multiple of `N`.
///
/// `N == 0` means "no alignment" and returns `n` unchanged, as does an `n` that is already a
/// multiple of `N`. Any `N` is supported, not only powers of two.
///
/// # Panics
///
/// In debug builds, panics on overflow when the rounded-up value does not fit in `usize`
/// (release builds wrap, as for any `usize` addition).
#[inline(always)]
pub const fn align_up_to<const N: usize>(n: usize) -> usize {
    if N == 0 {
        return n;
    }
    match n % N {
        0 => n,
        // Only add the distance to the next multiple, so already-aligned values near
        // `usize::MAX` never overflow.
        rem => n + (N - rem),
    }
}
/// Allocates heap storage for a single `T`, aligned to `max(align_of::<T>(), ALIGN)`, and wraps
/// it in a `Box<T>` without initializing it.
///
/// For a zero-sized `T` nothing is allocated and the box holds a dangling pointer aligned for
/// `T` (not for `ALIGN`).
///
/// # Safety
///
/// - The boxed value is uninitialized: the caller must write a valid `T` before reading it and
///   before the box is dropped (dropping runs `T`'s destructor on the contents).
/// - `max(align_of::<T>(), ALIGN)` must be a power of two, as required by `alloc_aligned`
///   (the `Layout::from_size_align_unchecked` contract).
///
/// NOTE(review): when `ALIGN > align_of::<T>()`, the allocation layout differs from the
/// `Layout::new::<T>()` that `Box<T>` uses to deallocate on drop, while
/// `GlobalAlloc::dealloc` requires the same layout that was used to allocate. Confirm such
/// boxes are never dropped normally, or that the global allocator tolerates the mismatch.
#[cfg(feature = "alloc")]
#[inline(always)]
pub unsafe fn alloc_aligned_box<T, const ALIGN: usize>() -> Box<T> {
    match NonZeroUsize::new(core::mem::size_of::<T>()) {
        Some(size) => {
            let ptr = alloc_aligned(
                size,
                core::cmp::max(core::mem::align_of::<T>(), ALIGN),
            );
            Box::from_raw(ptr as *mut T)
        }
        // Zero-sized `T`: nothing to allocate, and `Box` does not deallocate zero-sized values.
        None => Box::from_raw(core::ptr::NonNull::dangling().as_ptr()),
    }
}
/// Allocates an uninitialized boxed slice of `len` values of `T`, aligned to
/// `max(align_of::<T>(), ALIGN)`.
///
/// Nothing is allocated when `len == 0` or `T` is zero-sized. The returned slice then points
/// at a dangling, `T`-aligned address and still reports `len` elements.
///
/// # Safety
///
/// The elements are uninitialized. The caller must initialize every element before reading it.
/// If `T` has drop glue, every element must also be initialized before the box is dropped.
///
/// NOTE(review): when `ALIGN > align_of::<T>()`, `Box<[T]>` deallocates with
/// `align_of::<T>()` on drop, which differs from the allocation layout
/// (`GlobalAlloc::dealloc` requires them to match). Confirm callers account for this.
///
/// # Panics
///
/// Panics in either of these cases:
/// - `size_of::<T>() * len` overflows `usize`.
/// - The resulting layout is invalid: the size exceeds `isize::MAX` when rounded up to the
///   alignment, or `ALIGN` is larger than `align_of::<T>()` but not a power of two.
#[cfg(feature = "alloc")]
#[inline(always)]
pub unsafe fn alloc_aligned_box_slice<T, const ALIGN: usize>(len: usize) -> Box<[T]> {
    if core::mem::size_of::<T>() == 0 || len == 0 {
        // Nothing to allocate. `NonNull::<T>::dangling()` is non-null and aligned for `T`, as a
        // `Box<[T]>` requires. Keeping `len` (not 0) preserves the element count for zero-sized
        // `T`.
        let data = core::ptr::NonNull::<T>::dangling().as_ptr();
        return Box::from_raw(core::ptr::slice_from_raw_parts_mut(data, len));
    }
    let align = core::cmp::max(core::mem::align_of::<T>(), ALIGN);
    // Reject byte counts that overflow `usize` or form an invalid layout. Without this check
    // they would silently wrap into a too-small (or zero-sized) allocation.
    let layout = core::mem::size_of::<T>()
        .checked_mul(len)
        .and_then(|bytes| std::alloc::Layout::from_size_align(bytes, align).ok())
        .expect("alloc_aligned_box_slice: invalid allocation size or alignment");
    // SAFETY: The size is non-zero, because `size_of::<T>() > 0` and `len > 0` were checked
    // above and the product did not overflow. `Layout::from_size_align` validated the alignment
    // and size bounds that `alloc_aligned` requires.
    let ptr = alloc_aligned(NonZeroUsize::new_unchecked(layout.size()), layout.align());
    Box::from_raw(core::ptr::slice_from_raw_parts_mut(ptr.cast::<T>(), len))
}
/// Allocates `len` bytes of uninitialized memory aligned to `align` from the global allocator.
///
/// Never returns null: allocation failure is reported through
/// `std::alloc::handle_alloc_error`. The memory should be released with `std::alloc::dealloc`
/// using the same size and alignment.
///
/// # Safety
///
/// The caller must uphold the preconditions of `std::alloc::Layout::from_size_align_unchecked`:
/// `align` must be a non-zero power of two, and `len` rounded up to a multiple of `align` must
/// not exceed `isize::MAX`.
#[cfg(feature = "alloc")]
#[inline(always)]
pub unsafe fn alloc_aligned(len: NonZeroUsize, align: usize) -> *mut u8 {
    // SAFETY: the size/alignment preconditions are forwarded to the caller (see `# Safety`).
    let layout = unsafe { std::alloc::Layout::from_size_align_unchecked(len.get(), align) };
    // SAFETY: `layout` has a non-zero size because it comes from a `NonZeroUsize`, which is
    // what `std::alloc::alloc` requires.
    let block = unsafe { std::alloc::alloc(layout) };
    if block.is_null() {
        std::alloc::handle_alloc_error(layout)
    }
    block
}
/// Repeats a statement template once per argument group, expanding to straight-line code.
///
/// Forms (patterns are bound positionally through a tuple `let`):
///
/// - `unroll!([(a0, b0), (a1, b1)] => |p, q| body)` expands, for each group, to
///   `let (p, q) = (a0, b0); body;`.
/// - `unroll!(let [x: (a0, b0), y: (a1, b1)] => |p, q| body)` expands, for each group, to
///   `let (p, q) = (a0, b0); let x = body;`, so `x`, `y`, ... are introduced at the call site.
/// - `unroll!(let [x, y] => |p| body)` is shorthand for
///   `unroll!(let [x: (x), y: (y)] => |p| body)`: each identifier is both the input and the
///   name its result is rebound to.
///
/// The `@ [$dollar:tt]` arms are internal. They receive a literal `$` so the expansion can
/// define a helper `macro_rules! __unrolled` at the call site. The helper's body contains the
/// patterns and `body` once, and it is invoked with one group's values at a time.
#[macro_export]
macro_rules! unroll {
    (let [$( $x:ident ),+] => |$y:pat_param| $expr:expr) => {
        $crate::unroll!(let [$( $x: ($x) ),+] => |$y| $expr);
    };
    (let [$( $id:ident: ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => {
        $crate::unroll!(@ [$] [$( $id: ($( $x ),+) ),+] => |$( $y ),+| $($expr)+)
    };
    (@ [$dollar:tt] [$( $id:ident: ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => {
        macro_rules! __unrolled {
            ($dollar id:ident, $dollar ( $dollar z:tt ),+) => {
                // With a single pattern this is `let (p) = (v);`, hence the lint allowance.
                #[allow(unused_parens)]
                let ($( $y ),+) = ($dollar ( $dollar z ),+);
                let $dollar id = $($expr)+;
            };
        }
        $( __unrolled!($id, $( $x ),+); )+
        // Earlier direct (non-helper) form, kept for reference. Presumably abandoned because
        // `$y` (one repetition level) and `$x` (two levels) cannot be expanded together here.
        //$(
        // let ($( $y ),+) = ($( $x ),+);
        // $expr;
        //)+
    };
    ([$( ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => {
        $crate::unroll!(@ [$] [$( ($( $x ),+) ),+] => |$( $y ),+| $($expr)+)
    };
    (@ [$dollar:tt] [$( ($( $x:expr ),+) ),+] => |$( $y:pat_param ),+| $($expr:tt)+) => {
        macro_rules! __unrolled {
            ($dollar ( $dollar z:tt ),+) => {
                // With a single pattern this is `let (p) = (v);`, hence the lint allowance.
                #[allow(unused_parens)]
                let ($( $y ),+) = ($dollar ( $dollar z ),+);
                $($expr)+;
            };
        }
        $( __unrolled!($( $x ),+); )+
        // Earlier direct (non-helper) form, kept for reference. Presumably abandoned because
        // `$y` (one repetition level) and `$x` (two levels) cannot be expanded together here.
        //$(
        // let ($( $y ),+) = ($( $x ),+);
        // $expr;
        //)+
    };
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::array_op;

    #[test]
    fn test_align_down_to() {
        assert_eq!(align_down_to::<8>(8), 8);
        assert_eq!(align_down_to::<16>(8), 0);
        assert_eq!(align_down_to::<16>(16), 16);
        assert_eq!(align_down_to::<16>(15), 0);
        assert_eq!(align_down_to::<16>(17), 16);
    }

    #[test]
    fn test_align_up_to() {
        assert_eq!(align_up_to::<8>(0), 0);
        assert_eq!(align_up_to::<8>(1), 8);
        assert_eq!(align_up_to::<8>(8), 8);
        assert_eq!(align_up_to::<16>(17), 32);
        assert_eq!(align_up_to::<1>(7), 7);
    }

    #[test]
    fn test_array_op_gen() {
        assert_eq!(array_op!(gen[4] | i | i), [0, 1, 2, 3]);
        assert_eq!(
            array_op!(gen[16] | i | i),
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
        );
        assert_eq!(array_op!(gen[4] | i | i as u8), [0u8, 1, 2, 3]);
        assert_eq!(
            array_op!(gen[16] | i | i as u8),
            [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
        );
        const A: [u8; 4] = array_op!(gen[4] | i | i as u8);
        const B: [u8; 16] = array_op!(gen[16] | i | i as u8);
        assert_eq!(A, [0u8, 1, 2, 3]);
        assert_eq!(B, [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]);
        // 32 bytes: the width of a 256-bit vector such as `__m256i`. Spelled out instead of
        // `size_of::<core::arch::x86_64::__m256i>()` so the tests compile on every target.
        const LEN: usize = 256 / 8;
        const INDICES: [u8; LEN] = array_op!(gen[LEN] | i | (i as u8) >> 1);
        assert_eq!(
            INDICES,
            [
                0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12,
                13, 13, 14, 14, 15, 15
            ]
        );
    }

    #[test]
    fn test_array_op_map() {
        const SRC: [u8; 4] = [1, 2, 3, 4];
        assert_eq!(array_op!(map[4, SRC] |i, x| x as usize + i), [1, 3, 5, 7]);
        assert_eq!(cast_u8_u32([0, 1, 255]), [0u32, 1, 255]);
    }
}