最適輸送問題, Sinkhorn 距離, ソート

2021-06-03 (Thu.)

最適化

参考文献

この記事の内容は全て次のいずれかに書かれています.

  1. 輸送問題を近似的に行列計算で解く(機械学習への応用つき) - 私と理論
  2. 輸送問題の解を微分する - Qiita
  3. ソートの微分とソートの一般化 - Qiita

概要

最適輸送問題と呼ばれる最適化問題は Sinkhorn-Knopp アルゴリズムによって近似的に計算することが出来る. このアルゴリズムは全て線形変換で表現されるので微分可能であり, 勾配法で学習するような機械学習への応用が考えられる. この問題の応用例として数列のソートを紹介する.

私が理解できてない都合上, なぜ上手くいくのかといった理論部分には触れない.

最適輸送問題

ある製品の需要と供給について考える. \(N\) 箇所の工場があり, \(i\) 番目の工場はその製品を \(a_i\) だけ作る. 一方で \(M\) 箇所の小売店があり, \(j\) 番目の小売店は製品を \(b_i\) だけ要求する. \[\sum_i a_i = \sum_j b_j\] であるとき, 製品を各工場から各小売店に直接輸送することで供給を満たしたい. すなわち, \(i\) 番目の工場から \(j\) 番目の小売店には製品を \(P_{ij}\) だけ運ぶとすれば,

と表現できる. ただしここで各輸送方法にはコストが設定されており, \(i\) 番目の工場から \(j\) 番目の小売店への輸送には量当たり \(C_{ij}\) のコストが掛かる. したがって輸送方法 \(P\) に対して全体として \[\sum_i \sum_j C_{ij} P_{ij}\] のコストが掛かる. これを最小化しようというのが 最適輸送問題 である.

ただし、本記事では \(N = M = n\) であるとする.

形式的表現

文章で問題を表現したが列ベクトルと行列表示すれば最適輸送問題は次で表現される. ただしここで \(\mathbb R_+\) は非負整数全体とする.

ちなみに制約の表現だが, 全ての成分が \(1\) の列ベクトル \(1_n \in \{1\}^n\) を用いると次のように書き直せる:

最適輸送距離

2つのベクトル \(a, b \in \mathbb R_+^n\) と暗に与えられるコスト行列 \(C\) に対して最適輸送問題を解いた結果の \[\min_P L(P)\]\(a\)\(b\) の距離 \[d(a, b; C) = \min_P L(P)\] と定義することができる. これを 最適輸送距離 と呼ぶ. これはいわゆる距離の公理を満たす.

この記事では特に \(a,b\) の次元を同じに揃えているが本来この2つは異なっていてもよくて, 問題なく定義される.

特にベクトル \(a, b\) を離散確率分布だとしたとき, この距離は Wasserstein 距離 と呼ばれ, 確率分布どうしの距離として定められる. また Earth-Mover (EM) 距離 と呼ばれる値もこれから定義される.

Sinkhorn-Knopp アルゴリズム, Sinkhorn 距離

最適輸送問題における目的関数を次のように変更する.

\[L(P) = \sum_i \sum_j \left[ C_{ij} P_{ij} + \frac{1}{\lambda} P_{ij} \log P_{ij} \right]\]

これはもとの目的関数から \(P\) に関するエントロピーを 減算 した形になっていて, エントロピー正則化 などとも呼ばれる. 最適化問題の文脈で言えば 緩和 だとも言える. この \(L(P)\) の最小値のことを Sinkhorn 距離 と呼ぶ.

ここで \(\lambda\) は正の定数だとする. \(\lambda \to +\infty\) のときに最適輸送問題と一致する. 十分大きな \(\lambda\) を与えることで Sinkhorn 距離で最適輸送距離を近似することができる.

さてエントロピー正則化された方の最適化問題は次に示す Sinkhorn-Knopp アルゴリズムで解くことができる.

Sinkhorn-Knopp アルゴリズム

  1. 行列 \(K\) を次で定める
  2. \(u \in \mathbb R^n\) をランダムに決める
  3. 収束するまで次を順次繰り返す
    1. \(v \leftarrow b ~/~ (K^\top u)\)
    2. \(u \leftarrow b ~/~ (K v)\)
      • ここで要素同士の除算を \(/\) と書いた
  4. \(u, v\) それぞれを対角に置いた \(n\times n\) の対角行列を \(U,V\) と置く
  5. 次の \(P\) が最適な輸送方法

計算コストが重たいのは 3.1, 3.2 における行列とベクトルの乗算, 5 の対角行列と行列の乗算であって, これらは計算量 \(O(n^2)\) で計算できる.

ソートへの応用

数列のソート(整列)は最適輸送問題で解くことができる.

ソートへの帰着

ソートしたい長さ \(n\) の数列 \(x \in \mathbb R^n\) が与えられたとする. ただしここで, 数列 \(x\) と列ベクトル \(x\) とは自然に同一視する. \(( \left( x_1, x_2, \ldots, x_n \right) \iff \left[ x_1, x_2, \ldots, x_n \right]^\top )\)

これに対して, 以下の \((y, a, b, C)\) を用意する.

ここで \((a,b,C)\) で定められる最適輸送問題を解いて, 最適な輸送方法が \(P \in \mathbb R^{n \times n}\) であったとする. このとき

とすると, \(S(x)\)\(x\) をソートした数列になっており, \(R(x)\)\(x\) のランクを表した数列になっている. つまり \(S(x)\) の第 \(i\) 成分は \(x\)\(i\) 番目に小さい数であって, \(R(x)\) の第 \(i\) 成分は \(x_i\)\(x\) で何番目に小さい数であるかの整数値になっている.

\(R(x)\) の表す ランク は 1-start の整数になっているが, これは \(r\) の並び替えをしてるに過ぎないので, \(r\) を変えれば 0-start の数字にする等できる.

数列 \(x = (4, 6, 2)\) について次を与える \((n=3)\).

このとき最適輸送は

であると求まる. これによって \(S(x), R(x)\) は次の通り計算される.

見ての通り \(P^\top\) 及び \(P\) はちょうど swap をするだけの行列操作になっている.