diff --git a/Cargo.toml b/Cargo.toml index ca9c4a4..61fd8ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "integer-sqrt" description = """ An implementation of integer square root algorithm for primitive rust types""" -version = "0.1.2" +version = "0.1.3" authors = ["Richard Dodd "] include = ["src/**/*.rs", "Cargo.toml"] repository = "https://github.com/derekdreery/integer-sqrt-rs" diff --git a/src/lib.rs b/src/lib.rs index 0373629..0e637ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,13 +29,19 @@ pub trait IntegerSquareRoot { /// For negative numbers (`i` family) this function will panic on negative input /// /// [wiki_article]: https://en.wikipedia.org/wiki/Integer_square_root - fn integer_sqrt(&self) -> Self where Self: Sized { - self.integer_sqrt_checked().expect("cannot calculate square root of negative number") + fn integer_sqrt(&self) -> Self + where + Self: Sized, + { + self.integer_sqrt_checked() + .expect("cannot calculate square root of negative number") } /// Find the integer square root, returning `None` if the number is negative (this can never /// happen for unsigned types). - fn integer_sqrt_checked(&self) -> Option where Self: Sized; + fn integer_sqrt_checked(&self) -> Option + where + Self: Sized; } // This could be more optimized @@ -65,9 +71,11 @@ macro_rules! impl_isqrt { let mut result = 0; loop { result = result << 1; - let candidate_result = result + 1; - if candidate_result * candidate_result <= *self >> shift { - result = candidate_result; + let candidate_result: $t = result + 1; + if let Some(cr_square) = candidate_result.checked_mul(candidate_result) { + if cr_square <= *self >> shift { + result = candidate_result; + } } if shift == 0 { break; @@ -83,103 +91,80 @@ macro_rules! impl_isqrt { }; } - impl_isqrt!(usize, u128, u64, u32, u16, u8, isize, i128, i64, i32, i16, i8); - #[cfg(test)] mod tests { use super::IntegerSquareRoot; - use core::{u8, u16, u64, i8}; + use core::{i8, u16, u64, u8}; + + macro_rules! gen_tests { + ($($type:ty => $fn_name:ident),*) => { + $( + #[test] + fn $fn_name() { + let newton_raphson = |val, square| 0.5 * (val + (square / val as $type) as f64); + let max_sqrt = { + let square = <$type>::max_value(); + let mut value = (square as f64).sqrt(); + for _ in 0..2 { + value = newton_raphson(value, square); + } + let mut value = value as $type; + // make sure we are below the max value (this is how integer square + // root works) + if value.checked_mul(value).is_none() { + value -= 1; + } + value + }; + let tests: [($type, $type); 9] = [ + (0, 0), + (1, 1), + (2, 1), + (3, 1), + (4, 2), + (81, 9), + (80, 8), + (<$type>::max_value(), max_sqrt), + (<$type>::max_value() - 1, max_sqrt), + ]; + for &(in_, out) in tests.iter() { + assert_eq!(in_.integer_sqrt(), out, "in {}", in_); + } + } + )* + }; + } + + gen_tests! { + i8 => i8_test, + u8 => u8_test, + i16 => i16_test, + u16 => u16_test, + i32 => i32_test, + u32 => u32_test, + i64 => i64_test, + u64 => u64_test, + u128 => u128_test, + isize => isize_test, + usize => usize_test + } #[test] - fn u8_sqrt() { - let tests = [ - (0u8, 0u8), + fn i128_test() { + let tests: [(i128, i128); 8] = [ + (0, 0), + (1, 1), + (2, 1), + (3, 1), (4, 2), - (7, 2), (81, 9), (80, 8), - (u8::MAX, (u8::MAX as f64).sqrt() as u8), + (i128::max_value(), 13_043_817_825_332_782_212), ]; for &(in_, out) in tests.iter() { assert_eq!(in_.integer_sqrt(), out, "in {}", in_); } } - - #[test] - fn i8_sqrt() { - let tests = [ - (0i8, 0i8), - (4, 2), - (7, 2), - (81, 9), - (80, 8), - (i8::MAX, (i8::MAX as f64).sqrt() as i8), - ]; - for &(in_, out) in tests.iter() { - assert_eq!(in_.integer_sqrt(), out, "in {}", in_); - } - } - - #[test] - #[should_panic] - fn i8_sqrt_negative() { - (-12i8).integer_sqrt(); - } - - #[test] - fn u16_sqrt() { - let tests = [ - (0u16, 0u16), - (4, 2), - (7, 2), - (81, 9), - (80, 8), - (u16::MAX, (u16::MAX as f64).sqrt() as u16), - ]; - for &(in_, out) in tests.iter() { - assert_eq!(in_.integer_sqrt(), out, "in {}", in_); - } - } - - #[test] - fn u64_sqrt() { - let sqrt_max = 4_294_967_295; - let tests = [ - (0u64, 0u64), - (4, 2), - (7, 2), - (81, 9), - (80, 8), - (u64::MAX, sqrt_max), - ]; - for &(in_, out) in tests.iter() { - assert_eq!(in_.integer_sqrt(), out, "in {}", in_); - } - // checks to make sure we have the right number for u64::MAX.integer_sqrt() - // we can't use the same strategy as in previous tests as f64 is now not returning the - // correct floored integer - assert!(sqrt_max * sqrt_max <= u64::MAX); - // check that the next number's square overflows - assert!((sqrt_max + 1).checked_mul(sqrt_max + 1).is_none()); - } - - #[test] - fn u128_sqrt() { - let sqrt_max: u128 = 18_446_744_073_709_551_615; - let tests = [ - (0u128, 0u128), - (4, 2), - (7, 2), - (81, 9), - (80, 8), - (u128::max_value(), sqrt_max), - ]; - for &(in_, out) in tests.iter() { - assert_eq!(in_.integer_sqrt(), out, "in {}", in_); - } - assert!(sqrt_max * sqrt_max <= u128::max_value()); - assert!((sqrt_max + 1).checked_mul(sqrt_max + 1).is_none()); - } }