アルゴリズム - 整数上の二分探索

概要

整数(のようなデータ)について Yes/No を返す述語 \(P\) があるとする: \[P \colon \mathbb Z \to \mathrm{Bool}.\]

そして今, この \(P\) はある整数 \(m\) があって,

  • \(n < m \implies P(n) = \mathrm{No}\)
  • \(n \geq m \implies P(n) = \mathrm{Yes}\)

を満たすとする. このとき, この整数 \(m\) を求めたい.

ただし, 次のような2つの値 \(l, r (l < r)\) が予め与えられるとする:

  • \(P(l) = \mathrm{No}\),
  • \(P(r) = \mathrm{Yes}\).

解法

区間 \((l,r]\)\(m\) を含んでいる. このことを不変条件に持つように上手く区間のサイズを半分にしてく. そのためには \(l, r\) の適当な中間値を持ってきて, それが \(P\) を満たすかをチェックするだけでいい. これを繰り返して, 区間のサイズがちょうど \(1\) になったとき, その要素が求める答え \(m\) である.

ところで \(l,r,m\) の乗ってるデータは, 中間値を取る操作 middle と, 区間のサイズが \(1\) であることをチェックする操作 close を必要とする. 逆に言えばこの二つさえあれば整数そのものに限らなくて良い. 例えば十分小さい値 eps を定めて close(l, r) = (r - l < eps) とすることで浮動小数点数であっても, 精度 epsm が求まる.

応用

整数として配列のインデックス (usize) を選び, prop を上手く作ることで, 昇順ソート済みの配列 xs 中に x がいくつあるか, x 以上/以下 がいくつあるか, などを対数時間で計算できる.

/// Algorithm - Binary Search
pub trait Integer
where
    Self: std::marker::Sized,
{
    fn close(range: std::ops::Range<Self>) -> bool;
    fn middle(range: std::ops::Range<Self>) -> Self;
}
macro_rules! define_integer {
    ($type:ty, $range:ident, $close_condition:expr, $middle_point:expr) => {
        impl Integer for $type {
            fn close($range: std::ops::Range<Self>) -> bool {
                $close_condition
            }
            fn middle($range: std::ops::Range<Self>) -> Self {
                $middle_point
            }
        }
    };
}
define_integer!(usize, r, r.start + 1 >= r.end, (r.start + r.end) / 2);
define_integer!(u32, r, r.start + 1 >= r.end, (r.start + r.end) / 2);
define_integer!(u64, r, r.start + 1 >= r.end, (r.start + r.end) / 2);
define_integer!(u128, r, r.start + 1 >= r.end, (r.start + r.end) / 2);
define_integer!(i32, r, r.start + 1 >= r.end, (r.start + r.end) / 2);
define_integer!(i64, r, r.start + 1 >= r.end, (r.start + r.end) / 2);
define_integer!(i128, r, r.start + 1 >= r.end, (r.start + r.end) / 2);
define_integer!(
    f32,
    r,
    r.start + 0.00000001 >= r.end,
    (r.start + r.end) / 2.0
);
define_integer!(
    f64,
    r,
    r.start + 0.00000001 >= r.end,
    (r.start + r.end) / 2.0
);

// the minimum index in range s.t. prop holds
pub fn binsearch<X: Integer + Copy>(range: std::ops::Range<X>, prop: &dyn Fn(X) -> bool) -> X {
    if prop(range.start) {
        range.start
    } else {
        let mut left = range.start;
        let mut right = range.end;
        while !X::close(left..right) {
            let mid = X::middle(left..right);
            if prop(mid) {
                right = mid;
            } else {
                left = mid;
            }
        }
        right
    }
}

#[cfg(test)]
mod test_binary_search {
    use crate::algorithm::binary_search::*;

    #[test]
    fn search_bound() {
        let v: Vec<i32> = (0..100).collect();
        assert_eq!(binsearch(0..100, &|i| v[i] > 50), 51);
        assert_eq!(binsearch(0..100, &|i| v[i] >= 0), 0);
        assert_eq!(binsearch(0..100, &|i| v[i] > 100), 100);
    }
}