DeepLearningを勉強する人

興味のあることを書く

Pointer Networks

[1506.03134] Pointer Networks

論文まとめ

入力系列上のインデックスに対応した要素から成る出力系列の条件付き確率分布を学習するアーキテクチャ.
この種の問題は、出力の各ステップでのターゲットクラスの数が、可変である入力長にいぞんしているので、Seq2SeqやNeural Turing Machineでは解くのは難しい.
(例えば、可変長のソート問題や種々の組み合わせ最適化問題など.)

既存のAttention

  • デコーダの各ステップで、エンコーダの隠れユニットをコンテキストに混ぜる

今回のAttention

  • 出力として、入力系列中の要素を指すポインタ(つまりインデックス)として利用する

Sequence-to-Sequence

\mathcal P = \{ P_1, ..., P_n \} : 入力系列
\mathcal C^{\mathcal P} = \{C_1, ..., C_{m( \mathcal P )} \},  m( \mathcal P ) \in [1, n ] : 出力系列(各出力は入力系列上のインデックス)

トレーニングデータのペア(\mathcal P, \mathcal C^{\mathcal P})が与えられた時、次の条件付き確率を、RNN (LSTM)によるパラメトリックモデルで推定するというもの.
{\displaystyle
p(\mathcal C^{\mathcal P} | \mathcal P ; \theta) = \prod_{i=1}^{m(\mathcal P)} p_{\theta} (C_i | C1, ..., C_{i-1}, \mathcal P; \theta)
}
以下のように、トレーニングセットの条件付き確率を最大化するように学習する.
{\displaystyle
\newcommand{\argmax}{\mathop{\rm arg~max}\limits}

\theta^* = \argmax_{\theta} \sum_{\mathcal P,  \mathcal C^{\mathcal P}} \log {p(\mathcal C^{\mathcal P} | \mathcal P ; \theta) }
}

統計的な独立性は仮定はせずに、RNNは各時刻iにおいて、P_iを入力系列の終わり(\Rightarrow)まで受け取り、出力系列の終わり(\Leftarrow)まで出力記号C_jを生成する.

推定時には、入力系列\mathcal Pが与えられ、学習済みパラメータ\theta^*を用いて、条件付き確率を最大化する列、つまり\newcommand{\argmax}{\mathop{\rm arg~max}\limits} \hat{\mathcal C}^{\mathcal P} = \argmax_{\mathcal C ^ {\mathcal P}} p ( \mathcal C ^ {\mathcal P} | \mathcal P; \theta^*) となる\hat{\mathcal C}^{\mathcal P}を選択する.

最適な系列\hat {\mathcal C}を見つけるには、計算量的に困難なので、ビームサーチを行ったりする.

このsequence-to-sequenceにおいては、出力は入力系列上のインデックスから選択されるのであるから、すべての記号C_iの数は入力列長であるnに固定されている.
つまり異なるnごとに学習しなくてはならない. 

ちなみに、出力の数がO(n)だとしたら、計算量はO(n)となる.

Content Based Input Attention

attentionというものを考えて、seq2seqでは固定的であったdecoderのステートに対してより多くの情報を付加する.

(e_1, ..., e_n): encoder hidden states
(d_1, ..., d_{m(\mathcal P)}): decoder hidden states
としたとき、attentionを以下のように定義する.

{ \displaystyle
\begin{eqnarray}
u_j^i & = &v^T \tanh(W_1e_j + W_2d_i) & j \in (1, ..., n) \\
a_j^i & = &{\rm softmax}(u_j^i) & j \in (1, ..., n) \\
d_i^{'} & = & \sum_{j=1}^n a_j^i e_j  \\
{\rm hidden \ states} & = & {\rm concat}(d_i^{'}, d_i)
\end{eqnarray}
}

 v, W_1, W_2は学習パラメータで、 {\rm softmax}{\bf u}^iを正規化してattention maskを生成する.
計算量的には推定時に、各出力でn命令処理するので、O(n^2)となる.

Seq2Seqより性能は良いが、やはり出力の辞書サイズが入力に依存するような問題には適用できない.

Ptr-Net

Seq2Seqでは、条件付き確率p(C_i | C_1, ..., C_{i-1}, \mathcal P)を計算するために、固定サイズの入力の辞書(系列上のインデックス)における{\rm softmax}の分布を使用する形になる.
したがって、出力が入力の辞書サイズに依存するような場合には適用できなかった. これに対して、以下のように、attentionを用いて条件付き確率p(C_i | C1, ..., C_{i-1}, \mathcal P)モデリングする.

{\displaystyle
\begin {eqnarray}
u_j^i & = & v^T \tanh(W_1e_j + W_2d_i) & \ j \in (1, ..., n) \\
p(C_i | C_1, ..., C_{i-1}, \mathcal P) & = & {\rm softmax}(u^i)
\end {eqnarray}
}

 v, W_1, W_2は学習パラメータで、 {\rm softmax}{\bf u}^iを正規化して、入力の辞書上(インデックス)の確率分布を生成する.
通常のattentionのように、「より多くの情報を伝播させるために、encoderのステートe_jを混ぜる」というようなことはせずに、u_j^iを入力要素へのポインタとして使う.
また、 C_{i-1}で条件付けするために、対応するP_{C_{i-1}}を入力としてコピーする.

図的には以下のような感じ
f:id:wanwannodao:20170612195610p:plain
(論文より引用)

データセット

Convex Hull (凸包)

定義
凸包 - Wikipedia


こんな感じのデータ

0.12597930 0.57132358 0.77404416 0.01266053 0.69612552 0.98888917 0.56750540 0.30860638 0.25714026 0.99675915 0.41245506 0.03328769 0.99328556 0.97091931 0.82174988 0.08516088 0.63969443 0.51914056 0.45612945 0.54733761 0.32766033 0.43352998 0.49206557 0.89107185 0.13685374 0.00708945 0.61040137 0.43254429 0.88256464 0.81985257 0.07880500 0.53008275 0.42095766 0.92055700 0.02109736 0.33024543 0.76352942 0.73969747 0.08505665 0.51877561 0.62335861 0.39605697 0.86642364 0.09540971 0.89609816 0.87439433 0.57799306 0.40433588 0.01053175 0.77368518 0.49862115 0.26769698 0.94832038 0.56638474 0.03807545 0.71314326 0.97767538 0.72042601 0.82861561 0.41455754 0.56748456 0.32859033 0.87639463 0.93457765 0.28872692 0.14781993 0.18529194 0.06272494 0.32126462 0.56453709 0.81442383 0.01964365 0.56290155 0.64332693 0.93979231 0.16170123 0.36700478 0.97992791 0.26060579 0.12514376 0.33918180 0.76817253 0.41583231 0.49321529 0.41187788 0.27050384 0.19393678 0.46065066 0.29478536 0.77983912 0.54389681 0.26205415 0.20471867 0.04153719 0.95478980 0.33200250 0.04618655 0.56410345 0.99342056 0.87685348 output 2 36 38 48 50 7 3 5 25 18 13 2

可視化するとこのような感じ

https://github.com/Wanwannodao/DeepLearning/blob/master/RNN/PtrNet/data.png?raw=true

実装

github.com

Decoder

dec_cell_in = utils.fc(dec_inputs, dec_inputs.get_shape()[-1],
    lstm_width, init_w=initializer, a_fn=tf.nn.elu)

(dec_cell_out, dec_state) = dec_cell(dec_cell_in, dec_state)

# W1, W2 are square matrixes (SxS)
# where S is the size of hidden states
W1 = tf.get_variable("W1", [lstm_width, lstm_width], dtype=tf.float32, initializer=initializer)
W2 = tf.get_variable("W2", [lstm_width, lstm_width], dtype=tf.float32, initializer=initializer)
# v is a vector (S)
v  = tf.get_variable("v", [lstm_width], dtype=tf.float32, initializer=initializer)

# W2 (SxS) d_i (S) = W2d (S)
W2d = tf.matmul(dec_state.h, W2)
# u_i (n)
u_i = []
                
for j in range(num_steps):
  # W1 (SxS) e_j (S) = W1e (S)
  # t = tanh(W1e + W2d) (S)
  t    = tf.tanh( tf.matmul(enc_states[j].h, W1) + W2d )
  # v^T (S) t (S) = U_ij (1)  
  u_ij = tf.reduce_sum(v*t, axis=1) # cuz t is acutually BxS

  u_i.append(u_ij)

u_i   = tf.stack(u_i, axis=1) # asarray
probs = tf.nn.softmax(u_i)

Loss

L2 Loss

self.loss = tf.nn.l2_loss(targets - self.C_prob)

入力データ

上の方で貼った図に倣い、入力の先頭(index 0)を終了記号に対応させるようにした.
具体的にどのような値を設定するべきなのかわからなかったので入力系列のインデックス0には[-1.0, -1.0]を対応させた.

トレーニングデータの出力系列はzero-padidngした。

enc.insert(0, '-1.0')
enc.insert(0, '-1.0')

while len(dec) != len(enc) // 2:
  dec = np.append(dec, _STOP)

結果

https://github.com/Wanwannodao/DeepLearning/blob/master/RNN/PtrNet/loss.png?raw=true

ある程度それっぽい出力をしている例
https://github.com/Wanwannodao/DeepLearning/blob/master/RNN/PtrNet/eval_6.png?raw=true
[ 2 11 21 21 39 28 40 50 47 2 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
https://github.com/Wanwannodao/DeepLearning/blob/master/RNN/PtrNet/eval_8.png?raw=true
[ 1 18 42 42 31 25 25 49 26 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
https://github.com/Wanwannodao/DeepLearning/blob/master/RNN/PtrNet/eval_18.png?raw=true
[ 2 2 44 12 30 36 36 21 29 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

https://github.com/Wanwannodao/DeepLearning/blob/master/RNN/PtrNet/eval_63.png?raw=true
[ 2 2 7 26 6 6 44 4 27 27 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
https://github.com/Wanwannodao/DeepLearning/blob/master/RNN/PtrNet/eval_75.png?raw=true
[ 5 34 34 32 8 43 43 45 18 25 5 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
https://github.com/Wanwannodao/DeepLearning/blob/master/RNN/PtrNet/eval_84.png?raw=true
[ 6 31 28 41 41 48 39 9 9 13 6 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]

微妙なものも多くあったので、実装に不備があるのかもしれない...