论文阅读:《Reformer: The Efficient Transformer》

本论文为谷歌近期发表的对Transformer改进的一篇论文,论文名字中的Efficient Transformer解释了论文的主要目的。过去一些基于Transformer结构的论文,一看到模型的总参数量就让人望而生畏,有些模型在我们的单卡GPU上根本跑不起来,因此就看了一下这篇论文。论文感觉比较偏工程,了解下它的大致思想就好。

参考:
https://arxiv.org/pdf/2001.04451.pdf
https://github.com/google/trax/blob/master/trax/models/reformer/reformer.py
https://zhuanlan.zhihu.com/p/92153420

论文摘要

为了解决Transformer模型在处理长序列时的GPU资源消耗问题,提出了更省内存和更快的Transformer模型结构。其改进主要有两点:

  • 使用locality-sensitive hashing代替dot-product attention,使得计算复杂度由 $O(L^2)$直接降为$O(LlogL)$,其中L为序列长度
  • 使用reversible residual layers来代替传统的残差层,使得训练过程中对激活函数的值的存储由N次降低为1次,其中N是层数。

背景介绍

比较大的Transformer模型里的每一层有0.5Billion的参数,最多可达到64层。并且随着序列长度增加,单个文本train example需要能处理11k左右的token。对于音乐、图像等数据,序列可能会更长,因此有些模型只能在大型GPU集群中进行并行训练。受GPU显存限制,有的模型也很难在单个GPU机器上进行微调。

因此会有这样一个疑问,这么大的Transformer模型到底在哪里消耗了这么多资源?我们不妨计算一下:

  • 每层0.5Billion的参数需要2GB的存储
  • 使用1024 embedding size和8 batch size训练的64k token的激活函数值需要64K 1K 8 = 0.5Billion的参数,即2GB的存储
  • N层网络需要将激活值存储N次(为了back-propagation时进行计算)
  • Feed-forward层的维度通常要比d_model大很多
  • 对于序列长度L来说计算attention所需要的时空复杂度为O(L^2)

具体计算过程如下:

图片

  • Transformer Block

    $h_{m i d}=\text { LayerNorm }\left(h_{i n}+\text { MultiHead }\left(h_{i n}\right)\right)$

    $h_{\text {out}}=\text {LayerNorm }\left(h_{\text {mid}}+\mathrm{FFN}\left(h_{\text {mid}}\right)\right)$

  • Multi-head Attention

$\begin{array}{l}\text {head}_{i}=\text { Attention }\left(Q W_{i}^{Q}, K W_{i}^{K}, V W_{i}^{V}\right) \ \text { Attention }(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right) V \ \text { MultiHead }(Q, K, V)=\text {Concat}\left(\text {head}_{1}, \ldots, \text {head}_{h}\right) W^{O}\end{array}$

  • Attention输入:
    • Q: (batch_size, seq_q, d_model)
    • K: (batch_size, seq_k, d_model)
    • V: (batch_size, seq_k, d_model)
  • Attention输出:
    • (batch_size, q_seq_len, d_model)

图片

  • Feed-Forward

    $\operatorname{FFN}(h)=\operatorname{ReLU}\left(h W_{1}+b_{1}\right) W_{2}+b_{2}$

  • FFN输入:

    • (batch_size, q_seq_len, d_model)
  • FFN输出:

    • (batch_size, q_seq_len, d_ff)

本论文提出了Reformer模型,使用了下面方法解决了内存和速度的问题:

  • Reversible layers:整个模型只需保存一次activations,使因层数导致的内存问题解决
  • FFN层分块并行处理:降低d_ff产生的内存消耗
  • 局部敏感哈希(locality-sensitive hashing):代替dot-product attention带来的O(L^2)计算和内存复杂度,使得能处理更长的序列

局部敏感哈希

参考:
https://arxiv.org/abs/1509.02897
https://www.cnblogs.com/maybe2030/p/4953039.html

Attention计算中最耗时和消耗内存的是QK^T([batch size, length, length])。我们其实关注的是softmax(QK^T),而softmax的取值主要被其中较大的元素主导,因此对Q的每个向量qi,只需要关注K中哪个向量最接近qi。比如说如果K的长度是64K,对于每个qi,我们只需要关注其中跟qi距离最近的32或64个kj。

图片

我们首先想到的是 locality-sensitive hashing,其特点是对于每个向量x,在经过哈希函数h(x)后,在原来的空间中挨的近的向量有更大的概率获得相同的哈希值。就像上面这张图,经过旋转(映射)后,距离远得点(第一行)有很大概率分到不同得桶中,而距离近得点(第二行)很大概率分到相同得桶中。

在实现时我们使用了一个随机产生的大小为(dk, b/2)的矩阵R,定义$h(x)=\arg{\max }([x R ;-x R])​$为哈希函数,这样所有x,可以把它们分配到b个哈希桶里。具体的计算和证明在另一片论文(Practical and Optimal LSH for Angular Distance)中。

下面这张图说明了LSH具体的计算流程:

图片

在上图中,不同的颜色表示不同的哈希值,相似的词则具有相同的颜色。分配哈希值后,序列重新排列,将具有相同哈希值的元素放在一起,再分为多个片段(或多个区块)以实现并行处理。然后在这些短得多的区块(及其相近邻块以覆盖溢出)内应用注意力,从而大大降低计算负载。

上图右侧(a-b)是和传统注意力的比较。(a)表明传统的注意力是很稀疏的,也就是说大多数的字符其实都不用关注;(b) k和q根据它们的哈希桶(注意随机的那个矩阵R是共享的)排序好,然后再使用。

由于哈希桶的大小很可能不均匀,所以我们首先令$k_{j}=\frac{q_{i}}{\left|q_{i}\right|}​$来保证$h\left(k_{j}\right)=h\left(q_{j}\right)​$,然后再从小到大给Q的哈希桶排序,在每个桶内部,按照位置先后排序。这实际上定义了一个置换$i \mapsto s_{i}​$。在排序后的注意力矩阵中,来自同一个哈希桶的(q,k)对会聚集在矩阵的对角(上图右c)。最后,把它们分组,每组m个,在各组内相互关注。

为了进一步减小桶分布不均的情况,可以用不同的哈希函数进行多轮哈希。下表是几种注意力方式的时空复杂度:(l: 序列长度,b: batch_size, $n_h$: num of heads, $n_c$: num of LSH chunk, $n_r$: num of hash repetition)图片

可逆层

参考:
https://arxiv.org/pdf/1707.04585.pdf
https://zhuanlan.zhihu.com/p/60479586
https://www.cnblogs.com/gczr/p/12181354.html

通过LSH可以将attention的复杂度减少为序列长度的线性级,但是参数量占的复杂度依旧很高,我们想要进一步减少。在上面表中我们看出,每一层的输入前都至少有$b \cdot l \cdot d_{\text {model}}$的激活输出值,$n_l$层则至少有个$b \cdot l \cdot d_{\bmod e l} \cdot n_{l}$。而且光是FFN层就会产生$b \cdot l \cdot d_{f f} \cdot n_{l}$的激活输出,对于一些大模型,这个$d_ff$会比较大(4K甚至64K),甚至消耗掉16GB的内存。因此采用可逆层来解决$n_l$和$d_ff$的问题。

可逆Transformer

可逆残差网络的前向传播和反向计算过程如下图:

图片

前向:

$\begin{array}{l}y_{1}=x_{1}+\mathcal{F}\left(x_{2}\right) \ y_{2}=x_{2}+\mathcal{G}\left(y_{1}\right)\end{array}​$

逆向:

$\begin{array}{l}x_{2}=y_{2}-\mathcal{G}\left(y_{1}\right) \ x_{1}=y_{1}-\mathcal{F}\left(x_{2}\right)\end{array}$

在典型的残差网络中,通过网络传递的输入将会向堆栈中的每一层不断添加至向量。相反,可逆层中每个层有两组激活。一组遵循刚才描述的标准过程,从一层逐步更新到下一层,但是另一组仅捕获第一层的变更。因此,若要反向运行网络,只需简单地减去每一层应用的激活。

简单来说,可逆层将输入分成两部分,使得每一层的值可以由它下一层的输出推导出来。因此整个网络只需要存储最后一层的值即可。

具体的解释可参考论文:The Reversible Residual Network: Backpropagation Without Storing Activations.

在Transformer中我们这样应用可逆层:

$Y_{1}=X_{1}+\text { Attention }\left(X_{2}\right)​$

$Y_{2}=X_{2}+\text { FeedForward }\left(Y_{1}\right)$

FF层分组

由于FFN层的计算不依赖于位置信息,可以将计算进行分块处理:$Y_{2}=\left[Y_{2}^{(1)} ; \ldots ; Y_{2}^{(c)}\right]=\left[X_{2}^{(1)}+\text { FeedForward }\left(Y_{1}^{(1)}\right) ; \ldots ; X_{2}^{(c)}+\text { FeedForward }\left(Y_{1}^{(c)}\right)\right]$

论文中特别强调,虽然通过分块和可逆层使得激活值是独立于层数的,但是对参数来说可不是这样,参数会随着层的增长而增长。好在我们可以利用CPU的内存,在逐层计算时将暂不使用的参数存储到CPU内存中,当需要时再交换回来。虽说从GPU到CPU的传输是比较慢的,但这对于Reformer来说,其batch_size * lenth已经达到可以忽略到这种参数传输的成本。

下表是所有变体的复杂度:

图片

实验效果

参考:
https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/image_generation.ipynb
https://colab.research.google.com/github/google/trax/blob/master/trax/models/reformer/text_generation.ipynb
https://colab.research.google.com/github/google/trax/blob/master/trax/intro.ipynb

论文中的实验结论主要是为了证实Reformer可以更高效,且对精度几乎没有损失。这里贴一张Colab上对Reformer应用的效果图:

图片

上图使用Reformer逐像素生成全画幅图像。