//! //! This module contains the single trait [`IntegerSquareRoot`] and implements it for primitive //! integer types. //! //! # Example //! //! ``` //! extern crate integer_sqrt; //! // `use` trait to get functionality //! use integer_sqrt::IntegerSquareRoot; //! //! # fn main() { //! assert_eq!(4u8.integer_sqrt(), 2); //! # } //! ``` //! //! [`IntegerSquareRoot`]: ./trait.IntegerSquareRoot.html #![no_std] #![feature(const_trait_impl)] #![feature(const_fn_trait_bound)] #![feature(const_option)] /// A trait implementing integer square root. pub trait IntegerSquareRoot { /// Find the integer square root. /// /// See [Integer_square_root on wikipedia][wiki_article] for more information (and also the /// source of this algorithm) /// /// # Panics /// /// For negative numbers (`i` family) this function will panic on negative input /// /// [wiki_article]: https://en.wikipedia.org/wiki/Integer_square_root #[default_method_body_is_const] fn integer_sqrt(&self) -> Self where Self: Sized, { self.integer_sqrt_checked() .expect("cannot calculate square root of negative number") } /// Find the integer square root, returning `None` if the number is negative (this can never /// happen for unsigned types). fn integer_sqrt_checked(&self) -> Option where Self: Sized; } macro_rules! integer_sqrt { ($ty:ty as $unsigned_ty:ty, $value:expr) => {{ static_assertions::assert_eq_size!($ty, $unsigned_ty); use core::cmp::Ordering; match $value { 0 => return Some(0), // Hopefully this will be stripped for unsigned numbers (impossible condition) v if v < 0 => return None, _ => {} } // Compute bit, the largest power of 4 <= n const ZERO: $ty = 0; let max_shift: u32 = ZERO.leading_zeros() - 1; let shift: u32 = (max_shift - $value.leading_zeros()) & !1; let mut bit = (1 as $unsigned_ty << shift) as $ty; // Algorithm based on the implementation in: // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2) // Note that result/bit are logically unsigned (even if T is signed). let mut n = $value; let mut result = 0; while bit != 0 { if n >= (result + bit) { n = n - (result + bit); result = (result as $unsigned_ty >> 1) as $ty + bit; } else { result = (result as $unsigned_ty >> 1) as $ty; } bit = (bit as $unsigned_ty >> 2) as $ty; } Some(result) }}; (impl const $ty:ty as $unsigned_ty:ty) => { impl const IntegerSquareRoot for $ty { fn integer_sqrt_checked(&self) -> Option { integer_sqrt!($ty as $unsigned_ty, *self) } } }; } integer_sqrt!(impl const u8 as u8); integer_sqrt!(impl const u16 as u16); integer_sqrt!(impl const u32 as u32); integer_sqrt!(impl const u64 as u64); integer_sqrt!(impl const u128 as u128); integer_sqrt!(impl const usize as usize); integer_sqrt!(impl const i8 as u8); integer_sqrt!(impl const i16 as u16); integer_sqrt!(impl const i32 as u32); integer_sqrt!(impl const i64 as u64); integer_sqrt!(impl const i128 as u128); integer_sqrt!(impl const isize as usize); #[cfg(test)] mod tests { use super::IntegerSquareRoot; use core::{i8, u16, u64, u8}; macro_rules! gen_tests { ($($type:ty => $fn_name:ident),*) => { $( #[test] fn $fn_name() { let newton_raphson = |val, square| 0.5 * (val + (square / val as $type) as f64); let max_sqrt = { let square = <$type>::max_value(); let mut value = (square as f64).sqrt(); for _ in 0..2 { value = newton_raphson(value, square); } let mut value = value as $type; // make sure we are below the max value (this is how integer square // root works) if value.checked_mul(value).is_none() { value -= 1; } value }; let tests: [($type, $type); 9] = [ (0, 0), (1, 1), (2, 1), (3, 1), (4, 2), (81, 9), (80, 8), (<$type>::max_value(), max_sqrt), (<$type>::max_value() - 1, max_sqrt), ]; for &(in_, out) in tests.iter() { assert_eq!(in_.integer_sqrt(), out, "in {}", in_); } } )* }; } gen_tests! { i8 => i8_test, u8 => u8_test, i16 => i16_test, u16 => u16_test, i32 => i32_test, u32 => u32_test, i64 => i64_test, u64 => u64_test, u128 => u128_test, isize => isize_test, usize => usize_test } #[test] fn i128_test() { let tests: [(i128, i128); 8] = [ (0, 0), (1, 1), (2, 1), (3, 1), (4, 2), (81, 9), (80, 8), (i128::max_value(), 13_043_817_825_332_782_212), ]; for &(in_, out) in tests.iter() { assert_eq!(in_.integer_sqrt(), out, "in {}", in_); } } }