PyTorch与caffe中SGD算法实现的一点小区别

Posted on

最近在复现之前自己之前的一个paper的时候发现PyTorch与caffe在实现SGD优化算法时有一处不太引人注意的区别,导致原本复制caffe中的超参数在PyTorch中无法复现性能。

这个区别于momentum系数$\mu$有关。简单地说,文献[1]和caffe的实现中,学习率放到了参数更新量$v_t$内部; 而在PyTorch的实现里,参数更新量根据momentum和当前损失对参数的梯度$\Delta f(\theta)$,求得,而后在更新参数的时候再乘上学习率。 在学习率$\varepsilon$固定的情况下,这两种实现是等价的。 但通常我们会在训练过程中降低(有部分情况会提高)学习率,这样导致了训练的不稳定。

假设目标函数是 $f(\theta)$ ,目标函数的导数是$\Delta f(\theta_t)$,那么[1]和caffe根据以下公式更新参数 $\theta$ :

$$ \begin{equation} v_{t+1} = \mu v_t - \varepsilon \Delta f(\theta_t) \end{equation} $$
$$ \begin{equation} \begin{split} \theta_{t+1} &= \theta_t + v_{t+1}\\ & = \theta_t + \mu \cdot v_t - \varepsilon \cdot \Delta f(\theta) \end{split} \end{equation} $$

(1)式中$\Delta f(\theta_t)$表示目标函数的导数,$\mu$表示momentum的系数(在[1]中被称为velocity),$\varepsilon$表示学习率。

我们先看caffe关于这部分的实现(代码在 https://github.com/BVLC/caffe/blob/99bd99795dcdf0b1d3086a8d67ab1782a8a08383/src/caffe/solvers/sgd_solver.cpp#L232-L234)

template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
  const vector<float>& net_params_lr = this->net_->params_lr();
  Dtype momentum = this->param_.momentum();
  Dtype local_rate = rate * net_params_lr[param_id];
  // Compute the update to history, then copy it to the parameter diff.
  switch (Caffe::mode()) {
  case Caffe::CPU: {
    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
              net_params[param_id]->cpu_diff(), momentum,
              history_[param_id]->mutable_cpu_data());
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());
    break;
  }
  case Caffe::GPU: {
#ifndef CPU_ONLY
    sgd_update_gpu(net_params[param_id]->count(),
        net_params[param_id]->mutable_gpu_diff(),
        history_[param_id]->mutable_gpu_data(),
        momentum, local_rate);
#else
    NO_GPU;
#endif
    break;
  }
  default:
    LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
  }
}

函数ComputeUpdateValue主要用于计算最后参数的更新值 ,也就是(2)式中的$v_{t+1}$。我们重点关注一下部分代码:

caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
              net_params[param_id]->cpu_diff(), momentum,
              history_[param_id]->mutable_cpu_data());

这里axpby就是$\alpha \cdot x + \beta \cdot y$,对应着local_rate就是学习率(之所以有local是因为caffe可以逐层设置学习率系数)。 net_params[param_id]->cpu_diff()就是参数的导数,也就是(1)式中的 $\Delta f(\theta_t)$。 history_[param_id]->mutable_cpu_data()也就是历史累计的momentum,对应的是$v_t$


我们再来看看PyTorch相关部分的代码(代码链接https://github.com/pytorch/pytorch/blob/9679fc5fcd36248ffe67f70d5c135d7af8ba0e2b/torch/optim/sgd.py#L88-L105):

def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
                        buf.mul_(momentum).add_(d_p)
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf

                p.data.add_(-group['lr'], d_p)

这里d_p是参数的导数,可以看到PyTorch的实现和(1)(2)式不太一样,PyTorch是按照下面的规则更新参数的:

$$ \begin{equation} v_{t+1} = \mu v_t - \Delta f(\theta_t) \end{equation} $$
$$ \begin{equation} \begin{split} \begin{split} \theta_{t+1} &= \theta_t + v_{t+1}\\ & = \theta_t + \mu \cdot v_t - \varepsilon \cdot \Delta f(\theta) \end{split} \end{split} \end{equation} $$ 为了方便比较,我们把(1)(2)式搬运过来:
$$ v_{t+1} = \mu v_t - \varepsilon \Delta f(\theta_t) $$
$$ \begin{split} \theta_{t+1} &= \theta_t + v_{t+1}\\ & = \theta_t + \mu \cdot v_t - \varepsilon \cdot \Delta f(\theta) \end{split} $$

可以看出来,相对于caffe的实现,PyTorch真正的momentum系数相当于caffe的momentum再乘以学习率$\varepsilon$。 因此使用PyTorch的时候,当学习率非常小(比如像我这样使用类似FCN结构的网络,学习率<1e-6),那么实际上的有效momentum是非常小的。

我不知道PyTorch是基于什么样的考虑要这样设计,文档中倒是有说这个区别,但是并没有解释 (文档链接

This is an image

最后,在[2] (Accurate, Large Minibatch SGD:Training ImageNet in 1 Hour)中,作者也提到过这两种不同实现的区别, 并提出了momentum correction来补偿PyTorch实现到来的不稳定性,相关内容如下:


[1] Sutskever, Ilya, et al. “On the importance of initialization and momentum in deep learning.“International conference on machine learning. 2013.

[2] Goyal, Priya, et al. “Accurate, large minibatch SGD: training imagenet in 1 hour.” arXiv preprint arXiv:1706.02677 (2017).


原文链接:http://kaiz.xyz/blog/posts/momentum-caffe-pytorch/,谢绝任何形式转载。