Wasserstein GAN (WGAN)
Wasserstein GAN (WGAN)
[1701.07875] Wasserstein GAN
([1701.04862] Towards Principled Methods for Training Generative Adversarial Networks WGANの話の前にこの話がある)
Martin Arjovsky氏の実装(Torch)
GitHub - martinarjovsky/WassersteinGAN
WGANをTensorFlowで実装した
github.com
論文まとめ
深く理解しようとするとかなり数学的知識が要求されるので、直感的な理解だけを追っていくことにする.
そもそもVAEやGANのアプローチというのは、自分の分布を真の分布に近づけていくというものだが、この分布間の距離/ダイバージェンスの定義の仕方によって、学習の収束性、安定性に差異がある.
基本的な距離/ダイバージェンスの定義としては以下の4つがある.
とはあるコンパクト空間上で定義される確率分布.(コンパクト空間は距離空間においては、有界閉集合をイメージすればよいらしい. ここでは画像の空間( )と思っていてよい)
- Total Variation (TV) distance
省略
- Kullback-Leibler (KL) divergence
非対称、発散する可能性がある
- Jensen-Shannon (JS) divergence
対称、常に定義可能
GANは近似的にこれを最小化することに対応する
- Earth-Mover (EM) distance (Wasserstein-1)
の直感的理解は、をに変えるために、地点から地点にどれだけの"質量"を移動しなければならないかという量.
よってEM距離は直感的には最適な輸送コストと考えれば良い.
(は同時分布を表す)
以下JSとEMだけ考えることにする.
次の簡単な状況下での距離の変化を観測してみる.
(一様分布)
: 確率分布
: パラメータ
簡単に言うと、二直線間の距離みたいなイメージか.
すると各距離は以下の通り.
よって下図のようなグラフが描ける.
(左がEM、右がJS)
(論文より引用)
この図からわかることは、EM距離は連続かつ有用な勾配(距離の直感に合う)、JS距離は非連続かつ距離の直感に合わない
論文では、諸定理を証明してEM距離が他の尺度より有利な性質を持っていることを述べている.
どうにかEM距離をロス関数として使いたいのだが、定義通りだと難しい...
どうするか??
まず、次のような双対問題が存在する.
- Kantorovich-Rubinstein duality
はある定数で、関数は-Lipschitz連続
-Lipschitz連続性の直感的な理解としては、関数の傾きがある定数で抑えられるような一様連続性であると思っている.
であるから、以下のような問題を考えればよい
はパラメータ化された関数で、GANの枠組みではDiscriminatorに対応する(論文中ではCritic)であり、-Lipschitz連続とする.
もし、先程の双対問題の上限がによって得られるならば、この最大化問題での定数倍を計算することができる.
は、に従って、ある分布(ex. Gaussian)に従う確率変数を(画像の空間)にマッピングする関数で、GANのGeneratorに対応するもの. (これが連続ならば、も連続となることも証明されている.)
さて、に関する勾配が以下になることも定理として示されている.
利点
- 学習の収束性、安定性
- 意味のある、しかも出力画像の質と相関するロス
- DiscriminatorとGeneratorの学習の進み具合を調整する必要がない
- Mode Collapseがない
実装
上でいろいろ書いたが実装に際してはやることは至って簡単.
(Critic = Discriminator)
- Criticを何回か更新して、の良い近似をする
- (Critic)の-Lipschitz連続性を保つ
- Generatorをを最小化する方向に更新する.
の計算
これは、上で述べた双対問題から出てきた最大化問題を解くので、以下の確率的勾配を昇る方向に更新する.
よって実装上は、符号反転して最小化問題とする.
(Critic)の-Lipschitz連続性
パラメータ空間がコンパクト空間(有界閉集合)なら(Critic)の-Lipschitz連続であることを意味する(らしい).
これは、weight clippingをするだけで達成可能である.
を最小化する
GeneratorはEM距離を最小化するので、次の確率的勾配を降る方向に更新する.
以上をまとめると以下のような雰囲気になる.
self.gen_img = self.G() # g(z) g_logits = self.C(self.gen_img, self.p) # f(g(z)) self.g_loss = -tf.reduce_mean(g_logits) # Generatorのロス self.c_loss = tf.reduce_mean(-self.C(self.X, self.p, reuse=True) + g_logits) # Criticのロス c_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5) g_opt = tf.train.RMSPropOptimizer(learning_rate=5e-5) c_grads_and_vars = c_opt.compute_gradients(self.c_loss) g_grads_and_vars = g_opt.compute_gradients(self.g_loss) # Criticに関する勾配のみ c_grads_and_vars = [[grad, var] for grad, var in c_grads_and_vars \ if grad is not None and var.name.startswith("C") ] # Generatorに関する勾配のみ g_grads_and_vars = [[grad, var] for grad, var in g_grads_and_vars \ if grad is not None and var.name.startswith("G") ] self.c_train_op = c_opt.apply_gradients(c_grads_and_vars) self.g_train_op = g_opt.apply_gradients(g_grads_and_vars) # Weight Clipping self.w_clip = [var.assign(tf.clip_by_value(var, -0.01, 0.01)) \ for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="C")]
# 1 Step内 # よい近似となるまで for n in range(5): batch_xs, _ = mnist.train.next_batch(64) batch_xs = batch_xs.reshape([-1, 28, 28, 1]) feed = {wgan.X: batch_xs} # Critic更新 _ = sess.run(wgan.c_train_op, feed_dict=feed) # Weight Clipping _ = sess.run(wgan.w_clip) # Generator更新 _ = sess.run(wgan.g_train_op)
ちなみに、DCGANの反省活かして少しはきれいな実装になったかと思う.
CriticとGeneratorのモデルは以下の通り.
(Batch Normalizationがなくてもよい結果は得られる、更にMLPでもよいとは書かれている)
ラッパーを色々用意したので割りとすっきり書けた(そんなことしなくてもTF-Slimを使えばいい...)
# ==================== # Generator # ==================== class Generator(): def __init__(self, batch_size): self.batch_size = batch_size def build_model(self, reuse): with tf.variable_scope("G", reuse=reuse): z = tf.random_uniform([self.batch_size, 100], minval=-1.0, maxval=1.0) fc1 = tf.nn.relu(batch_norm(fully_connected(z, 100, 1024, scope="fc1"), axes=[0])) fc2 = tf.nn.relu(batch_norm(fully_connected(fc1, 1024, 128*7*7, scope="fc2"), axes=[0])) fc2 = tf.reshape(fc2, [-1, 7, 7, 128]) convt1 = tf.nn.relu(batch_norm(convt(fc2, kernel=[5, 5, 64, 128], stride=[1, 2, 2, 1], output=[self.batch_size, 14, 14, 64], scope="convt1"), axes=[0, 1, 2])) convt2 = convt(convt1, kernel=[5, 5, 1, 64], stride=[1, 2, 2, 1], output=[self.batch_size, 28, 28, 1], activation_fn=tf.nn.tanh, scope="convt2") return convt2 def __call__(self, reuse=False): return self.build_model(reuse) # ==================== # Critic # ==================== class Critic(): def __init__(self, batch_size): self.batch_size = batch_size def build_model(self, X, p, reuse=False): with tf.variable_scope("C", reuse=reuse): conv1 = conv(X, kernel=[5, 5, 1, 64], stride=[1, 2, 2, 1], activation_fn=leaky_relu, scope="conv1") conv2 = conv(conv1, kernel=[5, 5, 64, 128], stride=[1, 2, 2, 1], activation_fn=leaky_relu, scope="conv2") convt2, dim = flatten(conv2) fc1 = fully_connected(convt2, dim, 256, activation_fn=leaky_relu, scope="fc1") # without sigmoid logits = fully_connected(fc1, 256, 1, scope="fc2") return logits def __call__(self, X, p, reuse=False): return self.build_model(X, p, reuse)