// Hunk of `.common/rust/src/main.rs` (index 70d2393..d8aa46b, @@ -66,6 +66,90 @@):
// the items below are added inside `mod math`, directly after the context line
// `mul(half, mul(half, b))` and before the module's closing brace.

/// Bit shifts carried out on the unsigned bit pattern of an integer.
///
/// For signed types the value is first reinterpreted as its unsigned
/// counterpart, shifted, and then reinterpreted back, so a right shift always
/// fills with zero bits instead of copies of the sign bit.
pub trait UnsignedShift {
    /// Shifts `self` left by `n` bits using the unsigned representation.
    ///
    /// As with the built-in `<<` operator, `n` should be smaller than the bit
    /// width of the type.
    fn unsigned_shl(self, n: u32) -> Self;

    /// Shifts `self` right by `n` bits using the unsigned representation
    /// (zero-filling, never sign-extending).
    ///
    /// As with the built-in `>>` operator, `n` should be smaller than the bit
    /// width of the type.
    fn unsigned_shr(self, n: u32) -> Self;
}

/// Integer square root.
pub trait ISqrt {
    /// Returns the integer square root of `self`.
    ///
    /// # Panics
    ///
    /// Panics if [`ISqrt::isqrt_checked`] returns `None`, i.e. when `self` is
    /// negative.
    fn isqrt(&self) -> Self
    where
        Self: Sized,
    {
        self.isqrt_checked()
            .expect("cannot calculate square root of negative number")
    }

    /// Returns the integer square root of `self`, or `None` when the square
    /// root is not defined for the value (negative input).
    fn isqrt_checked(&self) -> Option<Self>
    where
        Self: Sized;
}

/// Implements [`UnsignedShift`] and [`ISqrt`] for the integer type `$T`,
/// where `$U` is the unsigned integer type of the same width.
macro_rules! math_traits_impl {
    ($T:ty, $U:ty) => {
        impl UnsignedShift for $T {
            #[inline]
            fn unsigned_shl(self, n: u32) -> Self {
                // Same-width casts only reinterpret the bits; the shift itself
                // happens on the unsigned type.
                ((self as $U) << n) as $T
            }

            #[inline]
            fn unsigned_shr(self, n: u32) -> Self {
                ((self as $U) >> n) as $T
            }
        }

        impl ISqrt for $T {
            #[inline]
            fn isqrt_checked(&self) -> Option<Self> {
                use core::cmp::Ordering;

                let zero: $T = 0;
                match self.cmp(&zero) {
                    // Impossible for unsigned types; `cmp` is used instead of
                    // `< 0` so the expansion stays lint-free for them and the
                    // branch can be optimised away.
                    Ordering::Less => return None,
                    Ordering::Equal => return Some(zero),
                    Ordering::Greater => {}
                }

                // Compute `bit`, the largest power of 4 <= n.
                // `zero.leading_zeros()` is the bit width of the type.
                let max_shift: u32 = zero.leading_zeros() - 1;
                // Clearing the lowest bit rounds the shift down to an even
                // number, which makes `1 << shift` a power of 4.
                let shift: u32 = (max_shift - self.leading_zeros()) & !1;
                let one: $T = 1;
                let mut bit = one.unsigned_shl(shift);

                // Algorithm based on the implementation in:
                // https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)
                // Note that result/bit are logically unsigned (even if T is signed).
                let mut n = *self;
                let mut result = zero;
                while bit != zero {
                    if n >= result + bit {
                        n -= result + bit;
                        result = result.unsigned_shr(1) + bit;
                    } else {
                        result = result.unsigned_shr(1);
                    }
                    bit = bit.unsigned_shr(2);
                }
                Some(result)
            }
        }
    };
}

math_traits_impl!(i8, u8);
math_traits_impl!(u8, u8);
math_traits_impl!(i16, u16);
math_traits_impl!(u16, u16);
math_traits_impl!(i32, u32);
math_traits_impl!(u32, u32);
math_traits_impl!(i64, u64);
math_traits_impl!(u64, u64);
math_traits_impl!(i128, u128);
math_traits_impl!(u128, u128);
math_traits_impl!(isize, usize);
math_traits_impl!(usize, usize);