diff --git a/Cargo.toml b/Cargo.toml index 61fd8ab..8aed8fe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,3 +11,5 @@ keywords = ["integer", "square", "root", "isqrt", "sqrt"] categories = ["algorithms", "no-std"] license = "Apache-2.0/MIT" +[dependencies] +num-traits = "0.2" diff --git a/src/lib.rs b/src/lib.rs index ec0c129..bd65c40 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -44,43 +44,41 @@ pub trait IntegerSquareRoot { Self: Sized; } +#[inline(always)] +fn integer_sqrt_impl(mut n: T) -> Option { + // Hopefully this will be stripped for unsigned numbers (impossible condition) + if n < T::zero() { + return None; + } + + // Compute bit, the largest power of 4 <= n + use core::mem::size_of; + let mut bit = T::one().unsigned_shl(size_of::() as u32 * 8 - 2); + while bit > n { + bit = bit.unsigned_shr(2); + } + + // Algorithm based on the implementation in: + // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2) + // Note that result/bit are logically unsigned (even if T is signed). + let mut result = T::zero(); + while bit != T::zero() { + if n >= (result + bit) { + n = n - (result + bit); + result = result.unsigned_shr(1) + bit; + } else { + result = result.unsigned_shr(1); + } + bit = bit.unsigned_shr(2); + } + Some(result) +} + macro_rules! impl_isqrt { ($($t:ty)*) => { $( impl IntegerSquareRoot for $t { - #[allow(unused_comparisons)] fn integer_sqrt_checked(&self) -> Option { - // Hopefully this will be stripped for unsigned numbers (impossible condition) - if *self < 0 { - return None - } - // Find greatest shift - let mut shift = 2; - let mut n_shifted = *self >> shift; - // We check for n_shifted being self, since some implementations of logical - // right shifting shift modulo the word size. - while n_shifted != 0 && n_shifted != *self { - shift = shift + 2; - n_shifted = self.wrapping_shr(shift); - } - shift = shift - 2; - - // Find digits of result. - let mut result = 0; - loop { - result = result << 1; - let candidate_result: $t = result + 1; - if let Some(cr_square) = candidate_result.checked_mul(candidate_result) { - if cr_square <= *self >> shift { - result = candidate_result; - } - } - if shift == 0 { - break; - } - shift = shift.saturating_sub(2); - } - - Some(result) + integer_sqrt_impl(*self) } } )* };