182 lines
5.6 KiB
Rust
182 lines
5.6 KiB
Rust
//!
|
|
//! This module contains the single trait [`IntegerSquareRoot`] and implements it for primitive
|
|
//! integer types.
|
|
//!
|
|
//! # Example
|
|
//!
|
|
//! ```
|
|
//! extern crate integer_sqrt;
|
|
//! // `use` trait to get functionality
|
|
//! use integer_sqrt::IntegerSquareRoot;
|
|
//!
|
|
//! # fn main() {
|
|
//! assert_eq!(4u8.integer_sqrt(), 2);
|
|
//! # }
|
|
//! ```
|
|
//!
|
|
//! [`IntegerSquareRoot`]: ./trait.IntegerSquareRoot.html
|
|
#![no_std]
|
|
#![feature(const_trait_impl)]
|
|
#![feature(const_fn_trait_bound)]
|
|
#![feature(const_option)]
|
|
|
|
/// A trait implementing integer square root.
|
|
pub trait IntegerSquareRoot {
|
|
/// Find the integer square root.
|
|
///
|
|
/// See [Integer_square_root on wikipedia][wiki_article] for more information (and also the
|
|
/// source of this algorithm)
|
|
///
|
|
/// # Panics
|
|
///
|
|
/// For negative numbers (`i` family) this function will panic on negative input
|
|
///
|
|
/// [wiki_article]: https://en.wikipedia.org/wiki/Integer_square_root
|
|
#[default_method_body_is_const]
|
|
fn integer_sqrt(&self) -> Self
|
|
where
|
|
Self: Sized,
|
|
{
|
|
self.integer_sqrt_checked()
|
|
.expect("cannot calculate square root of negative number")
|
|
}
|
|
|
|
/// Find the integer square root, returning `None` if the number is negative (this can never
|
|
/// happen for unsigned types).
|
|
fn integer_sqrt_checked(&self) -> Option<Self>
|
|
where
|
|
Self: Sized;
|
|
}
|
|
|
|
/// Generates the integer square root implementation for a primitive integer type.
///
/// Two forms are accepted:
///
/// * `integer_sqrt!($ty as $unsigned_ty, $value)` expands to a block that `return`s an
///   `Option<$ty>`, so it must be used inside a function returning `Option<$ty>`.
///   `$unsigned_ty` must have the same size as `$ty`, which is checked at compile time. It is
///   used for the right shifts so they are logical (zero-filling) even when `$ty` is signed.
/// * `integer_sqrt!(impl const $ty as $unsigned_ty)` emits `impl const IntegerSquareRoot for
///   $ty`, implementing `integer_sqrt_checked` with the first form.
macro_rules! integer_sqrt {
    ($ty:ty as $unsigned_ty:ty, $value:expr) => {{
        static_assertions::assert_eq_size!($ty, $unsigned_ty);

        // Evaluate the input expression exactly once.
        let value: $ty = $value;

        match value {
            // Zero must be handled up front: its `leading_zeros()` equals the bit width, which
            // would make the `max_shift - leading_zeros` subtraction below underflow.
            0 => return Some(0),
            // Always false for unsigned types; the generated method allows
            // `unused_comparisons` for that reason.
            v if v < 0 => return None,
            _ => {}
        }

        // Compute bit, the largest power of 4 <= value. Here value > 0, so
        // `value.leading_zeros() <= max_shift` and the subtraction cannot underflow.
        const ZERO: $ty = 0;
        let max_shift: u32 = ZERO.leading_zeros() - 1;
        let shift: u32 = (max_shift - value.leading_zeros()) & !1;
        let mut bit = (1 as $unsigned_ty << shift) as $ty;

        // Algorithm based on the implementation in:
        // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)
        // Note that result/bit are logically unsigned (even if $ty is signed), so they are
        // shifted through $unsigned_ty to avoid sign extension.
        let mut n = value;
        let mut result = 0;
        while bit != 0 {
            if n >= (result + bit) {
                n = n - (result + bit);
                result = (result as $unsigned_ty >> 1) as $ty + bit;
            } else {
                result = (result as $unsigned_ty >> 1) as $ty;
            }
            bit = (bit as $unsigned_ty >> 2) as $ty;
        }

        Some(result)
    }};
    (impl const $ty:ty as $unsigned_ty:ty) => {
        impl const IntegerSquareRoot for $ty {
            // For unsigned `$ty` the generated `v < 0` guard is a comparison that can never be
            // true, which would otherwise raise `unused_comparisons` for every unsigned impl.
            #[allow(unused_comparisons)]
            fn integer_sqrt_checked(&self) -> Option<Self> {
                integer_sqrt!($ty as $unsigned_ty, *self)
            }
        }
    };
}
|
|
|
|
integer_sqrt!(impl const u8 as u8);
|
|
integer_sqrt!(impl const u16 as u16);
|
|
integer_sqrt!(impl const u32 as u32);
|
|
integer_sqrt!(impl const u64 as u64);
|
|
integer_sqrt!(impl const u128 as u128);
|
|
integer_sqrt!(impl const usize as usize);
|
|
|
|
integer_sqrt!(impl const i8 as u8);
|
|
integer_sqrt!(impl const i16 as u16);
|
|
integer_sqrt!(impl const i32 as u32);
|
|
integer_sqrt!(impl const i64 as u64);
|
|
integer_sqrt!(impl const i128 as u128);
|
|
integer_sqrt!(impl const isize as usize);
|
|
|
|
#[cfg(test)]
mod tests {
    use super::IntegerSquareRoot;

    /// Generates one test per `type => name` pair.
    ///
    /// Each test checks a handful of small exact roots. It also checks the two largest
    /// representable inputs against a floating-point reference refined with Newton–Raphson.
    macro_rules! gen_tests {
        ($($type:ty => $fn_name:ident),*) => {
            $(
                #[test]
                fn $fn_name() {
                    // One Newton–Raphson step refining `val` towards sqrt(`square`).
                    let newton_raphson = |val, square| 0.5 * (val + (square / val as $type) as f64);
                    let max_sqrt = {
                        let square = <$type>::MAX;
                        let mut value = (square as f64).sqrt();
                        for _ in 0..2 {
                            value = newton_raphson(value, square);
                        }
                        let mut value = value as $type;
                        // The integer square root is the largest r with r * r <= MAX, so step
                        // down if the float estimate landed one too high.
                        if value.checked_mul(value).is_none() {
                            value -= 1;
                        }
                        value
                    };
                    let tests: [($type, $type); 9] = [
                        (0, 0),
                        (1, 1),
                        (2, 1),
                        (3, 1),
                        (4, 2),
                        (81, 9),
                        (80, 8),
                        (<$type>::MAX, max_sqrt),
                        (<$type>::MAX - 1, max_sqrt),
                    ];
                    for &(in_, out) in tests.iter() {
                        assert_eq!(in_.integer_sqrt(), out, "in {}", in_);
                    }
                }
            )*
        };
    }

    gen_tests! {
        i8 => i8_test,
        u8 => u8_test,
        i16 => i16_test,
        u16 => u16_test,
        i32 => i32_test,
        u32 => u32_test,
        i64 => i64_test,
        u64 => u64_test,
        u128 => u128_test,
        isize => isize_test,
        usize => usize_test
    }

    /// Generates one test per signed `type => name` pair. Each test checks that negative
    /// inputs are rejected by `integer_sqrt_checked` and that non-negative ones are not.
    macro_rules! gen_negative_tests {
        ($($type:ty => $fn_name:ident),*) => {
            $(
                #[test]
                fn $fn_name() {
                    let inputs: [$type; 3] = [-1, -4, <$type>::MIN];
                    for &in_ in inputs.iter() {
                        assert_eq!(in_.integer_sqrt_checked(), None, "in {}", in_);
                    }
                    let square: $type = 81;
                    assert_eq!(square.integer_sqrt_checked(), Some(9));
                }
            )*
        };
    }

    gen_negative_tests! {
        i8 => i8_negative_test,
        i16 => i16_negative_test,
        i32 => i32_negative_test,
        i64 => i64_negative_test,
        i128 => i128_negative_test,
        isize => isize_negative_test
    }

    #[test]
    #[should_panic(expected = "cannot calculate square root of negative number")]
    fn negative_input_panics() {
        let _ = (-1i32).integer_sqrt();
    }

    // i128 is checked separately with a hard-coded root for i128::MAX (2^127 - 1).
    #[test]
    fn i128_test() {
        let tests: [(i128, i128); 9] = [
            (0, 0),
            (1, 1),
            (2, 1),
            (3, 1),
            (4, 2),
            (81, 9),
            (80, 8),
            (i128::MAX, 13_043_817_825_332_782_212),
            // 2^127 - 1 is prime (not a perfect square), so MAX - 1 has the same floor root.
            (i128::MAX - 1, 13_043_817_825_332_782_212),
        ];
        for &(in_, out) in tests.iter() {
            assert_eq!(in_.integer_sqrt(), out, "in {}", in_);
        }
    }
}
|