アルゴリズム - FFT - 数列の畳み込み

概要

2つの数列 \(A_i, B_i\) について, \[C_k = \sum_{i=0}^k A_i B_{k-i}\] なる数列 \(C_k\) を計算する.

ただし \(A_0=B_0=0\) であると仮定する. このとき \(C_0 = A_0 B_0 = 0\) となる.

参考

/// Algorithm - Fast Fourier Transform
use crate::algebra::complex::*;
use crate::algebra::monoid::*;

/// Namespace type for an FFT-based convolution of real-valued sequences.
pub struct FFT;
impl FFT {
    /// Recursive radix-2 transform shared by `fft` and `defft`.
    ///
    /// The primitive root used at each level is `Complex(cos θ, dir * sin θ)`
    /// with `θ = 2π / n`, so `dir = 1.0` and `dir = -1.0` give mutually
    /// conjugate transforms. No normalisation is applied here.
    ///
    /// `f.len()` is expected to be a power of two; inputs with fewer than two
    /// elements are returned unchanged.
    fn _fft(f: &[Complex<f64>], dir: f64) -> Vec<Complex<f64>> {
        let n = f.len();
        if n < 2 {
            return f.to_vec();
        }
        debug_assert!(n.is_power_of_two(), "FFT length must be a power of two");
        let half = n / 2;

        // Split into even-indexed and odd-indexed samples.
        let (evens, odds): (Vec<Complex<f64>>, Vec<Complex<f64>>) =
            f.chunks_exact(2).map(|pair| (pair[0], pair[1])).unzip();
        let g0 = FFT::_fft(&evens, dir);
        let g1 = FFT::_fft(&odds, dir);

        let theta = 2.0 * std::f64::consts::PI / (n as f64);
        let z = Complex(theta.cos(), theta.sin() * dir);

        // `twiddle` walks through z^0, z^1, ..., z^(n-1); the half-size
        // transforms are periodic with period `half`, hence the modulo.
        let mut twiddle: Complex<f64> = Complex::one();
        let mut out = Vec::with_capacity(n);
        for i in 0..n {
            let k = i % half;
            out.push(g0[k] + (twiddle * g1[k]));
            twiddle = z * twiddle;
        }
        out
    }

    /// Transform with `dir = 1.0` (see `_fft`).
    fn fft(f: &[Complex<f64>]) -> Vec<Complex<f64>> {
        FFT::_fft(f, 1.0)
    }

    /// Transform with `dir = -1.0` (see `_fft`); the caller is responsible for
    /// the `1 / n` scaling.
    fn defft(f: &[Complex<f64>]) -> Vec<Complex<f64>> {
        FFT::_fft(f, -1.0)
    }

    /// Cyclic convolution of two equally long complex sequences.
    ///
    /// Both inputs are transformed, multiplied element-wise, scaled by
    /// `1 / x.len()` and transformed back.
    fn _convolution(x: &[Complex<f64>], y: &[Complex<f64>]) -> Vec<Complex<f64>> {
        debug_assert_eq!(x.len(), y.len(), "inputs must have the same length");
        let m = x.len();
        // Hoisted: the scale factor is the same for every element.
        let scale = 1.0 / m as f64;
        let xh = FFT::fft(x);
        let yh = FFT::fft(y);
        let product: Vec<Complex<f64>> = xh
            .iter()
            .zip(yh.iter())
            .map(|(&a, &b)| (a * b) * scale)
            .collect();
        FFT::defft(&product)
    }

    /// Lifts real samples to complex numbers and zero-pads them to `len` entries.
    fn zero_pad(src: &[f64], len: usize) -> Vec<Complex<f64>> {
        let mut padded = Vec::with_capacity(len);
        padded.extend(src.iter().map(|&re| Complex(re, 0.0)));
        padded.resize(len, Complex(0.0, 0.0));
        padded
    }

    /// Computes `C_k = sum_{i=0}^{k} f_i * g_{k-i}` for two real sequences.
    ///
    /// The returned vector has length `2 * p`, where `p` is the smallest power
    /// of two that is at least `max(f.len(), g.len())`; entries past
    /// `f.len() + g.len() - 2` are zero up to floating-point error.
    ///
    /// An empty input is treated as the all-zero sequence.
    ///
    /// # Panics
    ///
    /// Panics if a non-empty input does not start with `0.0`.
    // The `&Vec` parameters are kept so the public signature stays unchanged.
    #[allow(clippy::ptr_arg)]
    pub fn convolution(f: &Vec<f64>, g: &Vec<f64>) -> Vec<f64> {
        assert!(
            f.is_empty() || f[0] == 0.0,
            "convolution requires f[0] == 0.0"
        );
        assert!(
            g.is_empty() || g[0] == 0.0,
            "convolution requires g[0] == 0.0"
        );
        let n = std::cmp::max(f.len(), g.len());
        // Twice the next power of two: long enough that the cyclic convolution
        // computed below does not wrap around.
        let m = n.next_power_of_two() << 1;

        let x = FFT::zero_pad(f, m);
        let y = FFT::zero_pad(g, m);

        // Only the first tuple component of each result entry is returned.
        FFT::_convolution(&x, &y)
            .into_iter()
            .map(|c| c.0)
            .collect()
    }
}

#[cfg(test)]
mod test_fft {
    use crate::algorithm::fft::*;

    /// Convolves two five-term sequences and compares the leading terms of
    /// the result against hand-computed values.
    #[test]
    fn it_works() {
        let lhs = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let rhs = vec![0.0, 1.0, 2.0, 4.0, 8.0];
        let want = vec![0.0, 0.0, 1.0, 4.0, 11.0, 26.0, 36.0, 40.0, 32.0];
        let got = FFT::convolution(&lhs, &rhs);
        // Only the first `want.len()` entries are compared; indexing `got`
        // panics if the result is shorter than that.
        for (k, w) in want.iter().enumerate() {
            assert!((w - got[k]).abs() < 1e-6);
        }
    }
}