diff --git a/.common/rust/src/main.rs b/.common/rust/src/main.rs index a931f20..66f78ac 100644 --- a/.common/rust/src/main.rs +++ b/.common/rust/src/main.rs @@ -40,31 +40,118 @@ fn main() { #[allow(dead_code)] mod math { - const MOD: i64 = 1_000_000_007; + #[derive(Copy, Clone, Default)] + pub struct Z(i64); - pub fn add(a: i64, b: i64) -> i64 { - (a + b) % MOD - } + impl Z { + const MOD: i64 = 1_000_000_007; - pub fn sub(a: i64, b: i64) -> i64 { - ((a - b) % MOD + MOD) % MOD - } - - pub fn mul(a: i64, b: i64) -> i64 { - (a * b) % MOD - } - - pub fn exp(b: i64, e: i64) -> i64 { - if e == 0 { - return 1; + pub fn new(x: i64) -> Z { + Z(x.rem_euclid(Z::MOD)) } - let half = exp(b, e / 2); - if e % 2 == 0 { - return mul(half, half); + pub fn pow(self, mut exp: u32) -> Z { + let mut ans = Z::new(1); + let mut base = self; + while exp > 0 { + if exp % 2 == 1 { + ans *= base; + } + base *= base; + exp >>= 1; + } + ans } - mul(half, mul(half, b)) + pub fn inv(self) -> Z { + assert_ne!(self.0, 0); + self.pow((Z::MOD - 2) as u32) + } + } + + impl std::ops::Neg for Z { + type Output = Z; + + fn neg(self) -> Z { + Z::new(-self.0) + } + } + + impl std::ops::Add for Z { + type Output = Z; + + fn add(self, rhs: Z) -> Z { + Z::new(self.0 + rhs.0) + } + } + + impl std::ops::Sub for Z { + type Output = Z; + + fn sub(self, rhs: Z) -> Z { + Z::new(self.0 - rhs.0) + } + } + + impl std::ops::Mul for Z { + type Output = Z; + + fn mul(self, rhs: Z) -> Z { + Z::new(self.0 * rhs.0) + } + } + + impl std::ops::Div for Z { + type Output = Z; + + fn div(self, rhs: Z) -> Z { + #![allow(clippy::suspicious_arithmetic_impl)] + self * rhs.inv() + } + } + + impl std::ops::AddAssign for Z { + fn add_assign(&mut self, rhs: Z) { + *self = *self + rhs; + } + } + + impl std::ops::SubAssign for Z { + fn sub_assign(&mut self, rhs: Z) { + *self = *self - rhs; + } + } + + impl std::ops::MulAssign for Z { + fn mul_assign(&mut self, rhs: Z) { + *self = *self * rhs; + } + } + + impl std::ops::DivAssign for Z { + fn div_assign(&mut self, rhs: Z) { + *self = *self / rhs; + } + } + + impl std::fmt::Display for Z { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } + } + + impl std::fmt::Debug for Z { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } + } + + impl std::str::FromStr for Z { + type Err = std::num::ParseIntError; + + fn from_str(s: &str) -> Result { + Ok(Z::new(s.parse()?)) + } } /// A trait implementing the unsigned bit shifts.