損失関数をいい感じにいじってより学習に易しい損失関数を作る.
損失関数に基づく機械学習というのは次のようなものである.
真のデータ分布 \(D\) の上で観測される \((X,Y)\) の分布について \(L\) の期待値
\[L_D(w) := \mathbb E \left[ \ell(w,x,y) \right]\]を最小化するような \(w\) を探索することを学習と呼ぶ.
さて真のデータ分布 \(D\) は分からないので実際にはサンプリング
\[S = \{ (x_i, y_i) \}_{i=1,\ldots,N}\]の上で推定した
\[L_S(w) := \frac{1}{N} \sum_S \ell(w,x,y)\]を変わりの近似値として使う.
ここで最急降下法によって最適な \(w\) を得る場合には勾配が重要になる. 特に最適解付近では勾配が緩やかな方が良いとされている. SAM は損失関数自体を少しいじることで, そんな良い勾配を持った関数にすることを目指す.
任意の正数 \(\rho > 0\) に対して, ある狭義単調増加関数 \(h \colon \mathbb R_+ \to \mathbb R_+\) によって次が成り立つ (with highly probability):
\[L_D(w) \leq \max_{\|\epsilon\|_2 \leq \rho} L_S(w + \epsilon) + h(\|w\|_2^2 / \rho^2).\]Theorem 1 の右辺の \(L_D\) に対する上限を参考に, 損失関数とする.
\[L_S^{\mathrm{SAM}}(w) := \max_{\| \epsilon \| \leq \rho} L(w + \epsilon) + \lambda \|w\|_2^2\]さてこの max を効率的に近似的に計算する. \(L(w + \epsilon)\) を \(\epsilon \sim 0\) の周辺で一次までテイラー展開して
\[L(w) + \epsilon^\top \nabla_w L(w) + \lambda \|w\|_2^2\]を代わりに考えることにすると, \(\epsilon\) に関わる
\[\max_\epsilon ~~ \epsilon^\top \nabla_w L(w)\]だけすればいいことになる.
で、謎の理屈を使うと, 次の値を \(\epsilon\) として使えば良い.
\[\hat{\epsilon} = \rho ~ \mathrm{sign}(\nabla_w L_S(w)) \frac{ \left| \nabla_w L_S(w) \right| }{ \sqrt{ \| \nabla_w L_S(w) \|_2^2 } }\]ここで分子にある \(\left| a \right|\) は要素ごとに絶対値をとったもの.