DeepLearningを勉強する人

興味のあることを書く

Wasserstein GAN (WGAN)

Wasserstein GAN (WGAN)

[1701.07875] Wasserstein GAN
([1701.04862] Towards Principled Methods for Training Generative Adversarial Networks WGANの話の前にこの話がある)
Martin Arjovsky氏の実装(Torch)
GitHub - martinarjovsky/WassersteinGAN

WGANをTensorFlowで実装した
github.com

論文まとめ

深く理解しようとするとかなり数学的知識が要求されるので、直感的な理解だけを追っていくことにする.

そもそもVAEやGANのアプローチというのは、自分の分布{ \mathbb{P}_\theta }を真の分布{\mathbb{P}_r }に近づけていくというものだが、この分布間の距離/ダイバージェンスの定義の仕方によって、学習の収束性、安定性に差異がある.
基本的な距離/ダイバージェンスの定義としては以下の4つがある.
{\mathbb{P}_r }{\mathbb{P}_g }はあるコンパクト空間{\mathcal{X}}上で定義される確率分布.(コンパクト空間は距離空間においては、有界閉集合をイメージすればよいらしい. ここでは画像の空間({\lbrack 0,1 \rbrack^d} )と思っていてよい)

  • Total Variation (TV) distance

{\displaystyle
 \delta(\mathbb{P}_r , \mathbb{P}_g) = \sup_{A \in \sum} |\mathbb{P}_r(A) - \mathbb{P}_g(A)| }
省略

  • Kullback-Leibler (KL) divergence

{\displaystyle
 KL(\mathbb{P}_r || \mathbb{P}_g) = \int \log( \frac{\mathbb{P}_r(x)}{\mathbb{P}_g(x)} ) \mathbb{P}_r(x) d\mu(x)
}
非対称、発散する可能性がある

  • Jensen-Shannon (JS) divergence

{\displaystyle
  JS(\mathbb{P}_r , \mathbb{P}_g) = KL(\mathbb{P}_r || \mathbb{P}_m) + KL(\mathbb{P}_g || \mathbb{P}_m) \\
  \mathbb{P}_m = \frac {\mathbb{P}_r + \mathbb{P}_g} {2}
}
対称、常に定義可能
GANは近似的にこれを最小化することに対応する

  • Earth-Mover (EM) distance (Wasserstein-1)

{\displaystyle
  W(\mathbb{P}_r , \mathbb{P}_g) = \inf_{\gamma \in \prod (\mathbb{P}_r , \mathbb{P}_g)}  \mathbb{E}_{(x,y) \sim \gamma} \lbrack ||x-y|| \rbrack
}

{\gamma(x,y)}の直感的理解は、{\mathbb{P}_r}{\mathbb{P}_g}に変えるために、地点{x}から地点{y}どれだけの"質量"を移動しなければならないかという量.
よってEM距離は直感的には最適な輸送コストと考えれば良い.
({\prod}は同時分布を表す)

以下JSとEMだけ考えることにする.
次の簡単な状況下での距離の変化を観測してみる.
{ \displaystyle
 Z \sim U \lbrack 0,1\rbrack} (一様分布)
{\mathbb{P}_0}: 確率分布 {(0, Z) \in \mathbb{R}^2 }
{g_{\theta}(z) = (\theta, z), \theta}: パラメータ
簡単に言うと、二直線間の距離みたいなイメージか.
すると各距離は以下の通り.

{\displaystyle
W(\mathbb{P}_r , \mathbb{P}_g) = |\theta| \\
\begin{eqnarray} 
JS(\mathbb{P}_r , \mathbb{P}_g) = \left\{ \begin{array}{}
\log 2 & if \theta \neq 0, \\
0 & if \theta = 0
\end{array} \right.
\end{eqnarray}
}

よって下図のようなグラフが描ける.
(左がEM、右がJS)
f:id:wanwannodao:20170228033758p:plain
(論文より引用)


この図からわかることは、EM距離は連続かつ有用な勾配(距離の直感に合う)JS距離は非連続かつ距離の直感に合わない

論文では、諸定理を証明してEM距離が他の尺度より有利な性質を持っていることを述べている.
どうにかEM距離をロス関数として使いたいのだが、定義通りだと難しい...

どうするか??
まず、次のような双対問題が存在する.

  • Kantorovich-Rubinstein duality

{\displaystyle
  K \cdot W(\mathbb{P}_r , \mathbb{P}_g) = \sup_{||f||_L <= K} \mathbb{E}_{x \sim \mathbb{P}_r \lbrack f(x) \rbrack} - \mathbb{E}_{x \sim \mathbb{P}_{\theta} \lbrack f(X) \rbrack}
}
{K}はある定数で、関数{f}{K}-Lipschitz連続
{K}-Lipschitz連続性の直感的な理解としては、関数の傾きがある定数{K}で抑えられるような一様連続性であると思っている.

であるから、以下のような問題を考えればよい
{ \displaystyle
 \max_{w \in \mathcal{W}} \mathbb{E}_{x \sim \mathbb{P}_r} \lbrack f_w(X) \rbrack - \mathbb{E}_{x \sim p(z)} \lbrack f_w(g_{\theta}(z) \rbrack
}

{\lbrace f_w {\rbrace}_{w \in \mathcal{W}}}はパラメータ化された関数で、GANの枠組みではDiscriminatorに対応する(論文中ではCritic)であり、{K}-Lipschitz連続とする.
もし、先程の双対問題の上限が{w \in \mathcal{W}}によって得られるならば、この最大化問題で{W(\mathbb{P}_r , \mathbb{P}_g)}の定数倍を計算することができる.
{g_{\theta}(z)}は、{\mathbb{P}_{\theta}}に従って、ある分布(ex. Gaussian)に従う確率変数{z}{\mathcal{X}}(画像の空間)にマッピングする関数で、GANのGeneratorに対応するもの. (これが連続ならば、{W(\mathbb{P}_r , \mathbb{P}_g)}も連続となることも証明されている.)
さて、{\theta}に関する勾配が以下になることも定理として示されている.
{\displaystyle
\nabla_{\theta} W(\mathbb{P}_r , \mathbb{P}_g) = -\mathbb{E}_{z 
\sim p(z)} \lbrack \nabla_{\theta} f(g_{\theta}(z)) \rbrack
}

利点
  • 学習の収束性、安定性
  • 意味のある、しかも出力画像の質と相関するロス
  • DiscriminatorとGeneratorの学習の進み具合を調整する必要がない
  • Mode Collapseがない

実装

上でいろいろ書いたが実装に際してはやることは至って簡単.
(Critic = Discriminator)

  1. Criticを何回か更新して、{W(\mathbb{P}_r , \mathbb{P}_g)}の良い近似をする
  2. {f_w}(Critic)の{K}-Lipschitz連続性を保つ
  3. Generatorを{W(\mathbb{P}_r , \mathbb{P}_g)}を最小化する方向に更新する.

{W(\mathbb{P}_r , \mathbb{P}_g)}の計算

これは、上で述べた双対問題から出てきた最大化問題を解くので、以下の確率的勾配を昇る方向に更新する.
よって実装上は、符号反転して最小化問題とする.
{ \displaystyle
 - \nabla_w \lbrack \frac{1}{m} \sum_{i=m}^{m} f_w(x^{(i)}) - \frac{1}{m} \sum_{i=m}^{m} f_w(g_{\theta}(z^{(i)})) \rbrack
}

{f_w}(Critic)の{K}-Lipschitz連続性

パラメータ空間{ \mathcal{W}}がコンパクト空間(有界閉集合)なら{f_w}(Critic)の{K}-Lipschitz連続であることを意味する(らしい).
これは、weight clippingをするだけで達成可能である.

{W(\mathbb{P}_r , \mathbb{P}_g)}を最小化する

GeneratorはEM距離を最小化するので、次の確率的勾配を降る方向に更新する.
{\displaystyle
 - \nabla_{\theta} \frac{1}{m} \sum_{i=m}^{m} f_w(g_{\theta}(z^{(i)})
}

以上をまとめると以下のような雰囲気になる.

self.gen_img = self.G() # g(z)
g_logits = self.C(self.gen_img, self.p) # f(g(z))

self.g_loss = -tf.reduce_mean(g_logits) # Generatorのロス
self.c_loss = tf.reduce_mean(-self.C(self.X, self.p, reuse=True) + g_logits) # Criticのロス

c_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
g_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5)
        
c_grads_and_vars = c_opt.compute_gradients(self.c_loss)
g_grads_and_vars = g_opt.compute_gradients(self.g_loss)

# Criticに関する勾配のみ
c_grads_and_vars = [[grad, var] for grad, var in c_grads_and_vars \
                            if grad is not None and var.name.startswith("C") ]
# Generatorに関する勾配のみ
g_grads_and_vars = [[grad, var] for grad, var in g_grads_and_vars \
                            if grad is not None and var.name.startswith("G") ]

self.c_train_op = c_opt.apply_gradients(c_grads_and_vars)
self.g_train_op = g_opt.apply_gradients(g_grads_and_vars)

# Weight Clipping
self.w_clip = [var.assign(tf.clip_by_value(var, -0.01, 0.01)) \
                       for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="C")]
# 1 Step内
# よい近似となるまで
for n in range(5):
    batch_xs, _ = mnist.train.next_batch(64)
    batch_xs = batch_xs.reshape([-1, 28, 28, 1])

    feed = {wgan.X: batch_xs}
    # Critic更新
    _ = sess.run(wgan.c_train_op,  feed_dict=feed)
    # Weight Clipping 
    _ = sess.run(wgan.w_clip)

# Generator更新        
_ = sess.run(wgan.g_train_op)

ちなみに、DCGANの反省活かして少しはきれいな実装になったかと思う.
CriticとGeneratorのモデルは以下の通り.
(Batch Normalizationがなくてもよい結果は得られる、更にMLPでもよいとは書かれている)
ラッパーを色々用意したので割りとすっきり書けた(そんなことしなくてもTF-Slimを使えばいい...)

# ====================
# Generator
# ====================
class Generator():
    def __init__(self, batch_size):
        self.batch_size = batch_size
        
    def build_model(self, reuse):
        with tf.variable_scope("G", reuse=reuse):
            z = tf.random_uniform([self.batch_size, 100], minval=-1.0, maxval=1.0)

            fc1 = tf.nn.relu(batch_norm(fully_connected(z, 100, 1024, scope="fc1"), axes=[0]))

            fc2 = tf.nn.relu(batch_norm(fully_connected(fc1, 1024, 128*7*7, scope="fc2"), axes=[0]))
            fc2 = tf.reshape(fc2, [-1, 7, 7, 128])

            convt1 = tf.nn.relu(batch_norm(convt(fc2, kernel=[5, 5, 64, 128],
                                                 stride=[1, 2, 2, 1],
                                                 output=[self.batch_size, 14, 14, 64],
                                                 scope="convt1"), axes=[0, 1, 2]))
            convt2 = convt(convt1, kernel=[5, 5, 1, 64],
                           stride=[1, 2, 2, 1],
                           output=[self.batch_size, 28, 28, 1],
                           activation_fn=tf.nn.tanh,
                           scope="convt2")
            return convt2

    def __call__(self, reuse=False):
        return self.build_model(reuse)

# ====================
# Critic
# ====================
class Critic():
    def __init__(self, batch_size):
        self.batch_size = batch_size
        
    def build_model(self, X, p, reuse=False):
        with tf.variable_scope("C", reuse=reuse):
            conv1 = conv(X, kernel=[5, 5, 1, 64], stride=[1, 2, 2, 1],
                         activation_fn=leaky_relu, scope="conv1")
            conv2 = conv(conv1, kernel=[5, 5, 64, 128], stride=[1, 2, 2, 1],
                         activation_fn=leaky_relu, scope="conv2")

            convt2, dim = flatten(conv2)
            fc1 = fully_connected(convt2, dim, 256, activation_fn=leaky_relu, scope="fc1")

            # without sigmoid
            logits = fully_connected(fc1, 256, 1, scope="fc2")
            return logits

    def __call__(self, X, p, reuse=False):
        return self.build_model(X, p, reuse)

結果

https://github.com/Wanwannodao/DeepLearning/blob/master/GAN/WGAN/image_0.png?raw=true

https://github.com/Wanwannodao/DeepLearning/blob/master/GAN/WGAN/image_19.png?raw=true

https://github.com/Wanwannodao/DeepLearning/blob/master/GAN/WGAN/loss.png?raw=true

確かに学習が進むに連れて、ロスも減っていくことは確認できたが最初の方の動作がよくわからなかった.
どっかで目にしたことだが、DCGANなどより多少ぼやけているか.