Fri Oct 25 2019

分類, 生成に求めるもの

学習データやテストデータは十分多くてすでに真の分布を表すものだとしよう. だとすれば分類の結果の分布は真の分布に近くあってほしい. でないと真の分類をしているという感じがしない.

例えば, 正例と負例とが 9:1 な偏ったデータだとする. 偏ったデータをうまく学習するテクニックは色々あるがそれはおいておいて。 テストデータでもやはり 9:1 だけ混じっているのだから, それを分類してみた結果はやはり 9:1 あることが望ましい. もちろん accuracy だけを見ると下手に 9:1 であるように矯正すると下がり得るけど, こちらのほうが真の分類をしてる感じがある.

生成についても同様. MNIST はだいたい "0" から "9" までが同じ枚数だけがデータにあるようになっている. だから 1:1:...:1 が真の分布だと言える. 素のGAN で問題になるのはいわゆる mode collapse ですなわち, 生成しやすい一文字 (私の経験上それはだいたい "8" であることが多い) だけを生成するようになってしまう問題がある. 分類のときと同様にやはり 1:1:...:1 で生成されるのが真の生成という感じがする.

GAN のバリエーションを担保させるテクニックはいくらでもあって, バッチの単位で計算をするのだから, そのバッチの中のエントロピーやなんかを大きくさせるようなロスを追加するとか, Discriminator にエントロピーを渡してしまう(mode collapse に陥ったときのエントロピーは特徴的だから, 生成データであることが見抜きやすくなる)というのも聞いたことがある.

私も一つ思いついた. Discriminator とは別に(あるいは併用してもいいかもしれないが)Label Classifier を後ろにくっつける. Label Classifier は生成されたデータのラベルを推定し, それは \(p = (p_i)_{i=1,2,\ldots,10}\) という10次元実ベクトルの形をしている. これをバッチの中で各生成データについて推定すれば, バッチのサイズを \(B\) とすれば \(B\) 個の実ベクトル \[p^1, p^2, \ldots, p^B\] が手に入る. これの平均を取る. \[\frac{1}{B} \sum_i p^i\] 先程言った真の分布に近いならこのベクトルは \[(0.1, 0.1, \ldots, 0.1)\] であるはずだ. というわけでこのベクトルどうしの KL 距離をロスに加える. 生成をしてから分類して結果の平均を取るまでの操作は全て微分可能.