chore(rs): add unsigned shifts and ‹isqrt›

Signed-off-by: Matej Focko <me@mfocko.xyz>
This commit is contained in:
Matej Focko 2023-07-24 23:07:50 +02:00
parent c78900b631
commit 7e4ed33c34
Signed by: mfocko
GPG key ID: 7C47D46246790496

View file

@ -66,6 +66,90 @@ mod math {
mul(half, mul(half, b))
}
/// A trait implementing the unsigned bit shifts.
pub trait UnsignedShift {
fn unsigned_shl(self, n: u32) -> Self;
fn unsigned_shr(self, n: u32) -> Self;
}
/// A trait implementing the integer square root.
pub trait ISqrt {
fn isqrt(&self) -> Self
where
Self: Sized,
{
self.isqrt_checked()
.expect("cannot calculate square root of negative number")
}
fn isqrt_checked(&self) -> Option<Self>
where
Self: Sized;
}
macro_rules! math_traits_impl {
($T:ty, $U: ty) => {
impl UnsignedShift for $T {
#[inline]
fn unsigned_shl(self, n: u32) -> Self {
((self as $U) << n) as $T
}
#[inline]
fn unsigned_shr(self, n: u32) -> Self {
((self as $U) >> n) as $T
}
}
impl ISqrt for $T {
#[inline]
fn isqrt_checked(&self) -> Option<Self> {
use core::cmp::Ordering;
match self.cmp(&<$T>::default()) {
// Hopefully this will be stripped for unsigned numbers (impossible condition)
Ordering::Less => return None,
Ordering::Equal => return Some(<$T>::default()),
_ => {}
}
// Compute bit, the largest power of 4 <= n
let max_shift: u32 = <$T>::default().leading_zeros() - 1;
let shift: u32 = (max_shift - self.leading_zeros()) & !1;
let mut bit = <$T>::try_from(1).unwrap().unsigned_shl(shift);
// Algorithm based on the implementation in:
// https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)
// Note that result/bit are logically unsigned (even if T is signed).
let mut n = *self;
let mut result = <$T>::default();
while bit != <$T>::default() {
if n >= (result + bit) {
n -= result + bit;
result = result.unsigned_shr(1) + bit;
} else {
result = result.unsigned_shr(1);
}
bit = bit.unsigned_shr(2);
}
Some(result)
}
}
};
}
math_traits_impl!(i8, u8);
math_traits_impl!(u8, u8);
math_traits_impl!(i16, u16);
math_traits_impl!(u16, u16);
math_traits_impl!(i32, u32);
math_traits_impl!(u32, u32);
math_traits_impl!(i64, u64);
math_traits_impl!(u64, u64);
math_traits_impl!(i128, u128);
math_traits_impl!(u128, u128);
math_traits_impl!(isize, usize);
math_traits_impl!(usize, usize);
}
#[allow(dead_code)]