Pointer Networks
論文まとめ
入力系列上のインデックスに対応した要素から成る出力系列の条件付き確率分布を学習するアーキテクチャ.
この種の問題は、出力の各ステップでのターゲットクラスの数が、可変である入力長にいぞんしているので、Seq2SeqやNeural Turing Machineでは解くのは難しい.
(例えば、可変長のソート問題や種々の組み合わせ最適化問題など.)
既存のAttention
- デコーダの各ステップで、エンコーダの隠れユニットをコンテキストに混ぜる
今回のAttention
- 出力として、入力系列中の要素を指すポインタ(つまりインデックス)として利用する
Sequence-to-Sequence
: 入力系列
] : 出力系列(各出力は入力系列上のインデックス)
トレーニングデータのペアが与えられた時、次の条件付き確率を、RNN (LSTM)によるパラメトリックモデルで推定するというもの.
以下のように、トレーニングセットの条件付き確率を最大化するように学習する.
統計的な独立性は仮定はせずに、RNNは各時刻において、を入力系列の終わり()まで受け取り、出力系列の終わり()まで出力記号を生成する.
推定時には、入力系列が与えられ、学習済みパラメータを用いて、条件付き確率を最大化する列、つまりとなるを選択する.
最適な系列を見つけるには、計算量的に困難なので、ビームサーチを行ったりする.
このsequence-to-sequenceにおいては、出力は入力系列上のインデックスから選択されるのであるから、すべての記号の数は入力列長であるに固定されている.
つまり異なるごとに学習しなくてはならない.
ちなみに、出力の数がだとしたら、計算量はとなる.
Content Based Input Attention
attentionというものを考えて、seq2seqでは固定的であったdecoderのステートに対してより多くの情報を付加する.
: encoder hidden states
: decoder hidden states
としたとき、attentionを以下のように定義する.
は学習パラメータで、はを正規化してattention maskを生成する.
計算量的には推定時に、各出力で命令処理するので、となる.
Seq2Seqより性能は良いが、やはり出力の辞書サイズが入力に依存するような問題には適用できない.
Ptr-Net
Seq2Seqでは、条件付き確率を計算するために、固定サイズの入力の辞書(系列上のインデックス)におけるの分布を使用する形になる.
したがって、出力が入力の辞書サイズに依存するような場合には適用できなかった. これに対して、以下のように、attentionを用いて条件付き確率をモデリングする.
は学習パラメータで、はを正規化して、入力の辞書上(インデックス)の確率分布を生成する.
通常のattentionのように、「より多くの情報を伝播させるために、encoderのステートを混ぜる」というようなことはせずに、を入力要素へのポインタとして使う.
また、で条件付けするために、対応するを入力としてコピーする.
図的には以下のような感じ
(論文より引用)
データセット
Convex Hull (凸包)
に
こんな感じのデータ
0.12597930 0.57132358 0.77404416 0.01266053 0.69612552 0.98888917 0.56750540 0.30860638 0.25714026 0.99675915 0.41245506 0.03328769 0.99328556 0.97091931 0.82174988 0.08516088 0.63969443 0.51914056 0.45612945 0.54733761 0.32766033 0.43352998 0.49206557 0.89107185 0.13685374 0.00708945 0.61040137 0.43254429 0.88256464 0.81985257 0.07880500 0.53008275 0.42095766 0.92055700 0.02109736 0.33024543 0.76352942 0.73969747 0.08505665 0.51877561 0.62335861 0.39605697 0.86642364 0.09540971 0.89609816 0.87439433 0.57799306 0.40433588 0.01053175 0.77368518 0.49862115 0.26769698 0.94832038 0.56638474 0.03807545 0.71314326 0.97767538 0.72042601 0.82861561 0.41455754 0.56748456 0.32859033 0.87639463 0.93457765 0.28872692 0.14781993 0.18529194 0.06272494 0.32126462 0.56453709 0.81442383 0.01964365 0.56290155 0.64332693 0.93979231 0.16170123 0.36700478 0.97992791 0.26060579 0.12514376 0.33918180 0.76817253 0.41583231 0.49321529 0.41187788 0.27050384 0.19393678 0.46065066 0.29478536 0.77983912 0.54389681 0.26205415 0.20471867 0.04153719 0.95478980 0.33200250 0.04618655 0.56410345 0.99342056 0.87685348 output 2 36 38 48 50 7 3 5 25 18 13 2
可視化するとこのような感じ
実装
Decoder
dec_cell_in = utils.fc(dec_inputs, dec_inputs.get_shape()[-1], lstm_width, init_w=initializer, a_fn=tf.nn.elu) (dec_cell_out, dec_state) = dec_cell(dec_cell_in, dec_state) # W1, W2 are square matrixes (SxS) # where S is the size of hidden states W1 = tf.get_variable("W1", [lstm_width, lstm_width], dtype=tf.float32, initializer=initializer) W2 = tf.get_variable("W2", [lstm_width, lstm_width], dtype=tf.float32, initializer=initializer) # v is a vector (S) v = tf.get_variable("v", [lstm_width], dtype=tf.float32, initializer=initializer) # W2 (SxS) d_i (S) = W2d (S) W2d = tf.matmul(dec_state.h, W2) # u_i (n) u_i = [] for j in range(num_steps): # W1 (SxS) e_j (S) = W1e (S) # t = tanh(W1e + W2d) (S) t = tf.tanh( tf.matmul(enc_states[j].h, W1) + W2d ) # v^T (S) t (S) = U_ij (1) u_ij = tf.reduce_sum(v*t, axis=1) # cuz t is acutually BxS u_i.append(u_ij) u_i = tf.stack(u_i, axis=1) # asarray probs = tf.nn.softmax(u_i)
Loss
L2 Loss
self.loss = tf.nn.l2_loss(targets - self.C_prob)
入力データ
上の方で貼った図に倣い、入力の先頭(index 0)を終了記号に対応させるようにした.
具体的にどのような値を設定するべきなのかわからなかったので入力系列のインデックス0には[-1.0, -1.0]を対応させた.
トレーニングデータの出力系列はzero-padidngした。
enc.insert(0, '-1.0') enc.insert(0, '-1.0') while len(dec) != len(enc) // 2: dec = np.append(dec, _STOP)
結果
ある程度それっぽい出力をしている例
[ 2 11 21 21 39 28 40 50 47 2 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[ 1 18 42 42 31 25 25 49 26 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[ 2 2 44 12 30 36 36 21 29 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[ 2 2 7 26 6 6 44 4 27 27 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[ 5 34 34 32 8 43 43 45 18 25 5 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
[ 6 31 28 41 41 48 39 9 9 13 6 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
微妙なものも多くあったので、実装に不備があるのかもしれない...