Use improved algorithm that only uses shifts

This new algorithm (taken from Wikipedia) only uses shifts, addditions,
and subtrations. On my x86_64 machine, the benchmarks are over twice as
fast.

This also takes num-traits as a dependancy, so that the implementation
can be a normal generic function, instread of a macro.

Signed-off-by: Joe Richey <joerichey@google.com>
This commit is contained in:
Joe Richey 2020-09-07 20:10:09 -07:00
parent 260148e029
commit 0445b3e5c1
No known key found for this signature in database
GPG Key ID: 1DD6D05AA306C53F
2 changed files with 33 additions and 33 deletions

View File

@ -11,3 +11,5 @@ keywords = ["integer", "square", "root", "isqrt", "sqrt"]
categories = ["algorithms", "no-std"]
license = "Apache-2.0/MIT"
[dependencies]
num-traits = "0.2"

View File

@ -44,43 +44,41 @@ pub trait IntegerSquareRoot {
Self: Sized;
}
#[inline(always)]
fn integer_sqrt_impl<T: num_traits::PrimInt>(mut n: T) -> Option<T> {
// Hopefully this will be stripped for unsigned numbers (impossible condition)
if n < T::zero() {
return None;
}
// Compute bit, the largest power of 4 <= n
use core::mem::size_of;
let mut bit = T::one().unsigned_shl(size_of::<T>() as u32 * 8 - 2);
while bit > n {
bit = bit.unsigned_shr(2);
}
// Algorithm based on the implementation in:
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)
// Note that result/bit are logically unsigned (even if T is signed).
let mut result = T::zero();
while bit != T::zero() {
if n >= (result + bit) {
n = n - (result + bit);
result = result.unsigned_shr(1) + bit;
} else {
result = result.unsigned_shr(1);
}
bit = bit.unsigned_shr(2);
}
Some(result)
}
macro_rules! impl_isqrt {
($($t:ty)*) => { $(
impl IntegerSquareRoot for $t {
#[allow(unused_comparisons)]
fn integer_sqrt_checked(&self) -> Option<Self> {
// Hopefully this will be stripped for unsigned numbers (impossible condition)
if *self < 0 {
return None
}
// Find greatest shift
let mut shift = 2;
let mut n_shifted = *self >> shift;
// We check for n_shifted being self, since some implementations of logical
// right shifting shift modulo the word size.
while n_shifted != 0 && n_shifted != *self {
shift = shift + 2;
n_shifted = self.wrapping_shr(shift);
}
shift = shift - 2;
// Find digits of result.
let mut result = 0;
loop {
result = result << 1;
let candidate_result: $t = result + 1;
if let Some(cr_square) = candidate_result.checked_mul(candidate_result) {
if cr_square <= *self >> shift {
result = candidate_result;
}
}
if shift == 0 {
break;
}
shift = shift.saturating_sub(2);
}
Some(result)
integer_sqrt_impl(*self)
}
}
)* };