[2011.10566] Exploring Simple Siamese Representation Learning

\[ \def\N{\mathbb{N}} \def\R#1{\mathbb{R}^{#1}} \def\Pr{\mathop{\mathrm{Pr}}} \]

Intro

Siamese ネットワークの方法は画像認識の分野では常識化した. これは正例としての2つの画像ペアが同じ(または似てる)というものと, 負例として2つの画像ペアが違うものを表しているというペアのデータセットを学習するものだった.

ここで紹介される Simple Siamese (以下 SimSiam) は負例ペアを必要としない. 正例ペアも同一の画像に異なるオーグメントを施すことで自動的に作るのでペアのデータセットは完全に不要.

こうなるとすべての画像について同一のベクトルを出すような崩壊 (collapsing) 現象が起きるのが普通だが, 簡単なテクニックでこれを防げた. 特に勾配を切る操作 (detach, stop-grad) がとても重要だった.

Method

Figure 1

全く同一のデータ \(x\) に異なる (ランダムな) オーグメントを施して \(x_1,x_2\) を作る. それぞれに encoder を施してこの出力が最終的に求めている画像表現ベクトルになる.

しかし工夫として, 一方にのみ predictor \(h\) を適用し, もう一方は勾配を切る. 図では左に \(h\) を適用したバージョンだけを書いているが, 実際には左右対称に, 右に \(h\) を適用して左の勾配を切るバージョンも当然考えられる. 実装ではその2つを計算してロスを足し合わせている.

def D(p, z):
    """損失関数

    コサイン類似度のマイナスを採用している
    ここで z の方は勾配を切っていることに注意
    """
    z = z.detach()  # stop gradient
    return -cosine(p, z)

# 学習可能なエンコーダー層
f = Encoder()

# 学習可能な予測層
h = Predictor()

# 学習ループ
for x in minibatches:
    x1, x2 = random_augment(x), random_augment(x)  # 独立にランダムなオーグメンテーションを施す
    z1, z2 = f(x1), f(x2)
    p1, p2 = h(z1), h(z2)
    loss1 = D(p1, z2)
    loss2 = D(p2, z1)
    loss = (loss1 + loss2) / 2
    loss.backward()
    optimizer.update()

エンコードで得た z と, そこに更に h を挟んで得た p とを比較する. この比較では p 側に使った方しかアップデートしない.

上のアルゴリズムでは, 一方に h を適用するのを両方実行して両方のロスを使ってる.

両方ともにアップデートすると, collapsing (同じ値に潰れる) 問題が起きるのでこれを少し回避できる. 加えて予測層 h を一方にだけ適用する. これで完全に collapsing が起きなくなる. h は浅い MLP 層で良い

実験

勾配を切る効果

この操作を無くすと損失はあっという間に減って, コサイン類似度が \(1\) になる. すなわち全ての画像について1つの値に潰れる (collapsing) ことがわかる.

ちなみにガウス分布に独立に従う \(d\) 次元ベクトル \(z\) を標準化したベクトル \(z / \|z\|_2\) のその標準偏差は \(1/ \sqrt{d}\) になる. エンコーダ z による出力について, その標準偏差を考えると大体そのくらいの値になるのが正しい学習に期待される挙動だ. これも一緒に調べると, 勾配を切ることで大体この値より少し小さいくらいを推移するが, 切らないとこれもすぐにゼロになる. 分散がゼロということなのでやはり一つの値に潰れてることが確認できた.

予測層 (Predictor) の効果

学習可能な予測層を置いておくことはめちゃくちゃ大事. パラメータ固定の MLP を置いておいても意味はなく, 学習が収束しなかった(しかしこれによって collapsing が起きたわけではない). 学習において lr decay は不要で, 学習率は固定でよかった.

バッチサイズについて

collapsing を防ぐ目的でバッチサイズを大きくするのがこれまでの一般的なテクだった. 今回の実験では精度への影響はややあったが, どちらにせよ collapsing は他の工夫で防げたのでバッチサイズは重要じゃなかった.

仮説: SimSiam は EM アルゴリズムになってる

勾配を止める操作がまさに交互に一方を学習してく EM アルゴリズム的になってそう.

画像 \(x\) , これをランダムにオーグメンテーションした \(T(x)\) , これをエンコードして得る \(F(T(x); \theta)\) があって, \(x\) に対して適切な表現ベクトル \(\eta_x\) になるように学習する.

\[\mathcal{L}(\theta, \eta) = \mathbb E \left[ \| F(T(x); \theta) - \eta_x \|^2_2 \right]\]

注意点として \(F\) のパラメータ \(\theta\) だけでなく, 表現ベクトル \(\eta_x\) そのものも学習すべきパラメータであるところ. この2つをアップデートするのに, 一方を固定して他方をアップデートすることを交互にやる. これが EM アルゴリズム的だといっている.

\(\eta_x\) を固定するのは普通に \(F(T(x); \theta)\) に関してだけアップデートすればいい.

\(\theta\) を固定して \(\eta_x\) を学習することを考えるが, これは解析的には次が最適解.

\[\eta_x \leftarrow \mathbb{E} \left[ F(T(x); \theta) \right]\]

そうすればいいのだが, 期待値を計算するのが実質無理なので, 直前の \(T', \theta'\) を使って

\[\eta_x = F(T'(x); \theta)\]

としておく( \(T\) はランダムなので \(T'\) と書いて区別しておく).

とすると,

\[\mathcal{L}(\theta, \eta) = \mathbb E \left[ \| F(T(x); \theta) - F(T'(x); \theta') \|^2_2 \right]\]

だし, パラメータ \(\theta\) の更新は

\[\theta \leftarrow \mathop{\mathrm{argmin}}_\theta ~ \mathbb E \left[ \| F(T(x); \theta) - F(T'(x); \theta') \|^2_2 \right]\]

になった. ここで \(T', \theta'\) が定数なわけだが, これがまさに勾配をストップすることに相当している.