Wasserstein GAN and Beyond

$$ \DeclareMathOperator{\E}{\mathbb{E}} $$ 本文是关于GAN[1]训练稳定性的几篇文章的笔记。 主要包括(但不仅限于)文献[2]分析了原生GAN[1]训练不稳定的原因在于其优化目标等价于 最小化真实数据分布$p_r$和生成数据分布$p_g$之间的J-S散度 $D_{JS}(p_r || p_g)$。 文献[3]提出使用连续性和数值稳定性更好的Wasserstein distance来代替原生GAN中的J-S散度。 文献[4]提出对判别器的参数采用“谱归一化” (spectral normalization,这里翻译成谱归一化很蹩脚), 从而使判别器满足lipschitz continuity。 因为在[2]简单的使用参数剪裁(weight clip)来限制判别器的参数,这样做过于粗暴而且会影响模型收敛。 在GAN原文[1]中整个模型(包括G和D)的优化目标为: \begin{equation} \begin{split} \min_G \max_D V(D, G) &= \E_{x \sim p_r} \left[\log D(x)\right] + \E_{z \sim p_z(z)} \left[\log(1 - D(G(z)))\right] \\ &= \E_{x \sim p_r} \left[\log D(x)\right] + \E_{x \sim p_z{z}} \left[\log(1 - D(G(z)))\right] \\ &= \E_{x \sim p_r} \left[\log D(x)\right] + \E_{x \sim p_g(x)} \left[\log(1 - D(x))\right] \\ &= \int_x \left[p_r(x)\log D(x) + p_g(x)\log(1-D(x))\right]dx \end{split} \label{eq:gan-loss} \end{equation} 这里解释一下$\ref{eq:gan-loss}$式的含义: 对$G$而言最好的状态就是$D(G(Z))=1$,也就是让$D$觉得G生成的图片是真实的, 此时$\ref{eq:gan-loss}$式中右边 $\log(1-D(G(z))) = \log{0} = -\infty$; 对$D$而言就是$D(x) = 1, x \sim p_r$,以及$D(G(z)) = 0$,此时$\ref{eq:gan-loss}$式为$\log{1} + \log{1} = 0$。 因此整个$\ref{eq:gan-loss}$式在$G$上求损失的最小值,在$D$上求损失的最大值。
Read more

PyTorch与caffe中SGD算法实现的一点小区别

最近在复现之前自己之前的一个paper的时候发现PyTorch与caffe在实现SGD优化算法时有一处不太引人注意的区别,导致原本复制caffe中的超参数在PyTorch中无法复现性能。 这个区别于momentum系数$\mu$有关。简单地说,文献[1]和caffe的实现中,学习率放到了参数更新量$v_t$内部; 而在PyTorch的实现里,参数更新量根据momentum和当前损失对参数的梯度$\Delta f(\theta)$,求得,而后在更新参数的时候再乘上学习率。 在学习率$\varepsilon$固定的情况下,这两种实现是等价的。 但通常我们会在训练过程中降低(有部分情况会提高)学习率,这样导致了训练的不稳定。 假设目标函数是 $f(\theta)$ ,目标函数的导数是$\Delta f(\theta_t)$,那么[1]和caffe根据以下公式更新参数 $\theta$ : $$ \begin{equation} v_{t+1} = \mu v_t - \varepsilon \Delta f(\theta_t) \end{equation} $$ $$ \begin{equation} \begin{split} \theta_{t+1} &= \theta_t + v_{t+1}\\ & = \theta_t + \mu \cdot v_t - \varepsilon \cdot \Delta f(\theta) \end{split} \end{equation} $$ (1)式中$\Delta f(\theta_t)$表示目标函数的导数,$\mu$表示momentum的系数(在[1]中被称为velocity),$\varepsilon$表示学习率。 我们先看caffe关于这部分的实现(代码在 https://github.com/BVLC/caffe/blob/99bd99795dcdf0b1d3086a8d67ab1782a8a08383/src/caffe/solvers/sgd_solver.cpp#L232-L234) template <typename Dtype> void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) { const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); const vector<float>& net_params_lr = this->net_->params_lr(); Dtype momentum = this->param_.
Read more