自然数/整数 - 関数 - 二項係数 (ModInt)
概要
二項係数の剰余を取った値 \(\left(\begin{array}{c}n\\k\end{array}\right) \bmod M\) を計算量 \(O(n \log n)\) で計算する.
/// Number - Binomial Coefficient on ModInt
use crate::algebra::modint::*;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Binom {
n: u64,
k: u64,
coeff: ModInt,
}
impl Binom {
pub fn unwrap(&self) -> ModInt {
self.coeff
}
/// Calc Binom-Coeff with O(k)
pub fn new(n: u64, k: u64, modulo: i64) -> Self {
if k == 0 {
let coeff = ModInt(1, modulo);
Self { n, k, coeff }
} else if n < k || n == 0 {
let coeff = ModInt(0, modulo);
Self { n, k, coeff }
} else if n < k * 2 {
let mut m = Self::new(n, n - k, modulo);
m.k = k;
m
} else {
let mut c = ModInt(1, modulo);
for i in 0..k {
c *= (n - i) as i64;
c /= (k - i) as i64;
}
Self { n, k, coeff: c }
}
}
/// Calc `binom(n, k)` with a Hint
pub fn new_with_hint(n: u64, k: u64, hint: &Binom) -> Self {
if k == 0 {
let coeff = ModInt(1, hint.unwrap().1);
return Self { n, k, coeff };
}
if n < k || n == 0 {
let coeff = ModInt(0, hint.unwrap().1);
return Self { n, k, coeff };
}
if n == hint.n && k == hint.k {
return *hint;
}
let (n_next, k_next, c_next) = if n < hint.n && k < hint.k {
let c = hint.unwrap() * hint.k as i64 / hint.n as i64;
(hint.n - 1, hint.k - 1, c)
} else if n > hint.n && k > hint.k {
let c = hint.unwrap() * (hint.n + 1) as i64 / (hint.k + 1) as i64;
(hint.n + 1, hint.k + 1, c)
} else if n > hint.n {
let c = hint.unwrap() * (hint.n + 1) as i64 / (hint.n - hint.k + 1) as i64;
(hint.n + 1, hint.k, c)
} else if n < hint.n {
let c = hint.unwrap() * (hint.n - hint.k) as i64 / hint.n as i64;
(hint.n - 1, hint.k, c)
} else if k > hint.k {
let c = hint.unwrap() * (hint.n - hint.k) as i64 / (hint.k + 1) as i64;
(hint.n, hint.k + 1, c)
} else {
let c = hint.unwrap() * hint.k as i64 / (hint.n - hint.k + 1) as i64;
(hint.n, hint.k - 1, c)
};
let nexthint = Binom {
n: n_next,
k: k_next,
coeff: c_next,
};
Self::new_with_hint(n, k, &nexthint)
}
}
#[cfg(test)]
mod test_binom_modint {
use crate::num::binom_modint::*;
#[test]
fn it_works() {
const MOD: i64 = 1000000007;
assert_eq!(Binom::new(5, 0, MOD).unwrap().unwrap(), 1);
assert_eq!(Binom::new(5, 1, MOD).unwrap().unwrap(), 5);
assert_eq!(Binom::new(5, 2, MOD).unwrap().unwrap(), 10);
assert_eq!(Binom::new(5, 3, MOD).unwrap().unwrap(), 10);
assert_eq!(Binom::new(5, 4, MOD).unwrap().unwrap(), 5);
assert_eq!(Binom::new(5, 5, MOD).unwrap().unwrap(), 1);
}
#[test]
fn large_numbers() {
const MOD: i64 = 107;
assert_eq!(Binom::new(100, 50, MOD).unwrap().unwrap(), 35);
}
#[test]
fn test_with_hint() {
const MOD: i64 = 107;
let c = Binom::new(5, 2, MOD);
assert_eq!(Binom::new_with_hint(4, 2, &c), Binom::new(4, 2, MOD));
for n in 3..8 {
for k in 0..=n {
assert_eq!(Binom::new_with_hint(n, k, &c), Binom::new(n, k, MOD));
}
}
}
#[test]
fn test_erroneous() {
const MOD: i64 = 107;
assert_eq!(
Binom::new(0, 0, MOD),
Binom {
n: 0,
k: 0,
coeff: ModInt(1, MOD)
}
);
assert_eq!(
Binom::new(0, 1, MOD),
Binom {
n: 0,
k: 1,
coeff: ModInt(0, MOD)
}
);
assert_eq!(
Binom::new(1, 2, MOD),
Binom {
n: 1,
k: 2,
coeff: ModInt(0, MOD)
}
);
assert_eq!(
Binom::new_with_hint(1, 2, &Binom::new(1, 1, MOD)),
Binom {
n: 1,
k: 2,
coeff: ModInt(0, MOD)
}
);
}
}