[2405.04517] xLSTM: Extended Long Short-Term Memory
概要
Long Short-Term Memory (LSTM) の拡張版 xLSTM を提案する. ベンチマークで Transformer に肉薄する結果を出せた.
変更は3つ
- 指数ゲーティング (exponential gating)
-
メモリ構造の変更
- スカラーメモリの sLSTM
- 行列メモリの mLSTM
-
以上を残渣ブロックに統合した
LSTM
ベース
LSTM の初期バージョンは 1991 年には提案されていた. 仔細を省いて概略を描くと, 入力列 \(z_t\) を受け取って隠れ状態の列 \(h_t\) を次のような漸化式で求める.
\[c_t = f_t c_{t-1} + i_t z_t\]
\[h_t = o_t c_t\]
ここで \(f, i, o\) がゲートと呼称されるもので, それぞれ forget, input, output を表現している.
多くの分野でLSTMは成功したが, 新たに出現した Transformer に比べるとさすがに弱い.
sLSTM
2つ新しいポイントがあって,
- normalizer state の導入
- input gate 及び forget gate に exp を使う
入力列 \(x_t\) について
-
state
-
cell state
- \(c_t = f_t c_{t-1} + i_t z_t\)
-
normalizer state
- \(n_t = f_t n_{t-1} + i_t\)
-
hidden state
- \(h_t = o_t( c_t / n_t )\)
-
cell input
-
\(z_t = \phi(\tilde{z_t})\)
- \(\tilde{z_t} = w_z^\top x_t + r_z h_{t-1} + b_z\)
-
gates
-
input gate
-
\(i_t = \exp(\tilde{i_t})\)
- \(\tilde{i_t} = w_i^\top x_t + r_i h_{t-1} + b_i\)
-
forget gate
-
\(f_t = \exp(\tilde{f_t})\) または \(f_t = \sigma(\tilde{f_t})\)
- \(\tilde{f_t} = w_f^\top x_t + r_f h_{t-1} + b_f\)
-
output gate
-
\(o_t = \sigma(\tilde{o_t})\)
- \(\tilde{o_t} = w_o^\top x_t + r_o h_{t-1} + b_o\)
さらに Milakov & Gimelshein, 2018 で提案された stabilizer テクニックがある. ただし exp にしてるのは本論文の新規性.
-
stabilizer state
- \(m_t = \max(\log f_t + m_{t-1}, \log i_t)\)
-
stabilized input gate
- \(i'_t = \exp(\tilde{i_t} - m_t)\)
-
stabilized forget gate
- \(f'_t = \exp( \log f_t + m_{t-1} - m_t)\)
これで出来た \(i', f'\) で \(i,f\) を置き換えるというもの. exp すると値が大きくなりすぎて数値計算上オーバーフローしうるのでこれを使う.
mLSTM
\(\def\R{\mathbb{R}}\) LSTM のスカラーメモリを \(c \in \R\) から行列 \(C \in \R^{d \times d}\) に拡張する. Transformer でいうところの key/value を使うため.
-
state
-
cell state
- \(C_t = f_t C_{t-1} + i_t (v_t k_t^\top)\)
-
normalizer state
- \(n_t = f_t n_{t-1} + i_t k_t\)
-
hidden state
-
\(h_t = o_t \odot \tilde{h_t}\)
- \(\tilde{h_t} = C_t q_t / \max(1, n_t^\top q_t)\)
-
input
-
query input
-
key input
- \(k_t = \frac{1}{\sqrt{d}} W_k x_t + b_k\)
-
value input
-
gate
-
input gate
-
\(i_t = \exp(\tilde{i_t})\)
- \(\tilde{i_t} = w_i^\top x_t + b_i\)
-
forget gate
-
\(f_t = \exp(\tilde{f_t})\) または \(\sigma(\tilde{f_t})\)
- \(\tilde{f_t} = w_f^\top x_t + b_f\)
-
output gate
-
\(o_t = \sigma(\tilde{o_t})\)
- \(\tilde{o_t} = W_o x_t + b_o\)
xLSTM
sLSTM または mLSTM を組み込んだブロックを残渣ブロックとして使う.
実験
sLSTM ブロックを a 個, mLSTM を b 個使ったものを xLSTM[a:b]
と記述する.
結果
LSTM を数十億のパラメータにスケールアップした結果 「Transformers や State Space Models と同程度に良い」といえる. スケーリング法則によれば, より大きな xLSTM モデルは現在の Transformer ベースの言語モデルの本格的な競合となる可能性がある.