1.摘要

多头注意力(MHA)是Transformer的关键组件。在MHA中,注意力头独立工作,导致诸如注意力得分矩阵的低秩瓶颈和头部冗余等问题。我们提出了动态可组合多头注意力(DCMHA),这是一种参数和计算高效的注意力架构,
解决了MHA的不足,并通过动态组合注意力头来增加模型的表达能力。DCMHA的核心是一个组合函数,该函数以输入相关的方式转换注意力得分和权重矩阵。
DCMHA可以作为任何Transformer架构中的MHA的直接替换,从而获得相应的DCFormer。DCFormer在不同架构和模型规模下的语言建模中显著优于Transformer,并且性能与计算量为1.7倍至2.0倍的模型相当。

2.引言

Transformer(Vaswani等,2017)已成为各种领域和任务的最先进模型,并成为基础模型的实际骨干。多头注意力(Multi-head Attention, MHA)是Transformer的关键组件,负责在标记之间进行信息交换。
MHA允许模型在不同位置同时关注来自不同表示子空间的信息。MHA的一个重要特征是多个注意力头并行且相互独立地工作。虽然它很简单,经验上成功的这一选择带来了一些缺点,
例如注意力得分矩阵的低秩瓶颈降低了表达能力,以及冗余头的问题导致参数和计算的浪费。
已有许多工作尝试通过引入头之间某种形式的交互或协作来改进多头注意力机制。我们从头组成的角度对这些工作进行分类,并以此为出发点来阐述我们的工作:从固定数量的“基础头”中组合出新的头。
头组合可以在多头注意力机制(MHA)的计算图中不同位置进行。头组合的一种常见形式是使用更复杂的方法来组合/选择多个头的输出,以替代MHA简单的拼接后投影方法,
这种方法可以是静态的(Ahmed等,2017)或动态的(Li等,2019;Zhang等,2022)。在MHA计算的最顶层操作时,这只是“表面”的组合形式:各个头仍然独立操作,标记之间的底层信息流保持不变。
由于这种性质,它通常是轻量且高效的,但同时获得的表达能力提升是有限的。
在最低层,另一种对比方法是组成MHA中头部的线性投影WQ、WK和WV(Cordonnier等,2020;Liu等,2022)。
投影组合允许真正新的头部通过实际改变的信息流来组成,这是提高表达能力的根本改进的前提。尽管理论上通过在投影之间共享参数更加参数高效,但这种方法在实践中通常会带来较大的计算成本。
此外,这种组合是静态的,缺乏对输入的适应性。头部组合的潜力无法完全实现。
我们选择采用第三种中间路径方法来组合注意力分数矩阵和/或注意力权重矩阵(在本文中统称为注意力矩阵)(Shazeer等,2020;Wang等,2022;Nguyen等,2022)。
注意力矩阵的组合与投影组合具有一些等效关系,确保了相对于头输出组合在表达能力上的基本提升。由于计算成本小于投影组合,并且经过精心设计,因此得益于此。
我们可以使组合具有动态性:根据输入新头可以即时组成,进一步提高模型的表现力。与现有工作相比,我们寻求同时满足真正组合性、动态性和效率的要求。
在这项工作中,我们提出了动态可组合多头注意力机制(DCMHA),这是一种参数和计算高效的注意力架构,解决了多头注意力机制(MHA)的不足,并通过动态组合注意力头来增加模型的表达能力。
DCMHA的核心是一个组合函数,该函数以输入相关的方式转换注意力得分和权重矩阵。DCMHA可以作为任何变换器架构中MHA的直接替换,从而获得相应的DCFormer。
我们实现了DCMHA/DCFormer,并进行了分析和广泛的实验来评估其有效性、效率和可扩展性。
实验结果表明,DCFormer在不同架构(原始或先进的LLaMA架构)和模型规模(从4.05亿到69亿)的语言建模中显著优于Transformer,并且性能与计算量为1.7倍至2倍的模型相当。
例如,DCPythia-6.9B在预训练困惑度和下游任务评估上都优于开源的Pythia-12B。我们还将DCMHA应用于图像分类的视觉变换器,并使用合成数据集对DCPythia-6.9B模型进行了一些初步分析,以更好地理解DCMHA的工作原理及其原因。

3.通过变换注意力矩阵进行头部组合

假设 $T$ 和 $S$ 分别是查询序列和键序列的长度。我们用 $A_h \in \mathbb{R}^{T \times S}$ 表示第 $h$ 个头的注意力矩阵,该矩阵可以是多头注意力模块(MHA)中 $H$ 个头中的一个的注意分数矩阵(softmax 之前)或权重矩阵(softmax 之后)。我们可以将 $H$ 个注意力矩阵堆叠到一个张量中:
$$
A = \text{Stack}({A_h}_{h=1}^H) \in \mathbb{R}^{H \times T \times S}.
$$
我们用:
$$
A[:,i,j] = A[:,i,j] \in \mathbb{R}^H
$$
表示查询向量 $Q_i$ 和键向量 $K_j$ 之间的一对注意力向量。
通过注意力矩阵组合(attention matrix composition),我们可以将 $H$ 个新的注意力矩阵 ${A’h}{h=1}^H$ 组合如下:第 $h$ 个组合后的矩阵 $A’h$ 是 $H$ 个基矩阵的线性组合:
$$
A’h = \sum{j=1}^H C
{hj} A_j,
$$
其中 $C \in \mathbb{R}^{H \times H}$ 是组合映射(composition map)。

定理1

通过组合映射 $C \in \mathbb{R}^{H \times H}$ 对注意力分数 ${A_i}_{i=1}^H$ 的组合等价于 QK 投影组合,且满足公式中定义的 $H$ 倍扩展。
类似地,我们有以下定理,说明注意力权重矩阵组合与以下 OV 投影组合之间的等价关系:

$$
\widetilde{W}^V_i = \text{Concat}{j \in [H]} \left[C{ij} W^V_j\right], \quad \widetilde{W}^O = \text{Tile}(W^O, (H, 1)) \tag{2}
$$

其中:
$$
\widetilde{W}^V_i \in \mathbb{R}^{D_m \times H D_h}, \quad \widetilde{W}^O \in \mathbb{R}^{H H D_h \times D_m}
$$
是组合后的投影矩阵。$\text{Tile}(W^O, (H, 1))$ 表示将 $W^O$ 沿其第一维重复 $H$ 次。

定理2

通过组合映射 $C \in \mathbb{R}^{H \times H}$ 对注意力权重 ${A_i}_{i=1}^H$ 的组合等价于 OV 投影组合,且满足公式 (2) 中定义的 $H$ 倍扩展。

注意力矩阵组合与基于扩展的投影组合的关系进一步验证了其有效性:
Bhojanapalli 等人(2020)表明,增大 QK 投影的头部维度可以缓解注意力分数矩阵的低秩瓶颈问题。
类似地,增大 OV 投影的头部维度可以增加头部之间的跨 token 信息传输带宽。因此,通过对注意力分数和注意力权重的组合,可以从根本上改进模型表达能力。

4.动态可组合多头注意力

在 MHA 中,注意力向量 $A_{i,j}$ 控制了查询向量 $Q_i$ 和键向量 $K_j$ 之间的信息流。在 DCMHA 的核心是一个 Compose 函数,该函数在给定 $Q_i$ 和 $K_j$ 的情况下,将它们的注意力向量 $A_{i,j} \in \mathbb{R}^H$ 转换为一个新的注意力向量 $A’{i,j}$,其公式如下:
$$
A’
{i,j} = \text{Compose}(A_{i,j}, Q_i, K_j; \theta) \tag{3}
$$
从高层次上来看,为了实现 DCMHA,我们只需在 MHA 的计算中插入两个 Compose 函数,其中一个作用在 softmax 之前的注意力分数张量 $A^S$ 上,另一个作用在 softmax 之后的注意力权重张量 $A^W$ 上:
$$
A^S_i = \frac{Q W^Q_i (W^K_i)^T}{\sqrt{D_h}}, \quad A^S = \text{Stack}(A^S_1, \ldots, A^S_H)
$$

$$
A^S = \text{Compose}(A^S, Q, K; \theta_\text{pre}) \tag{4}
$$

$$
A^W = \text{Softmax}(A^S, \text{dim} = -1)
$$

$$
A^W = \text{Compose}(A^W, Q, K; \theta_\text{post})
$$

$$
O_i = A^W_i (V W^V_i), \quad O = \text{Concat}(O_1, \ldots, O_H) W^O
$$

其中:

  • $W^Q_i, W^K_i, W^V_i$ 是第 $i$ 个头的投影矩阵。
  • $W^O \in \mathbb{R}^{H D_h \times D_m}$ 是输出投影矩阵。
  • $\text{Stack}$ 表示沿第一维堆叠,$\text{Concat}$ 表示沿最后一维拼接。

我们在公式 (3) 中使用了“批量版”Compose,即当查询和键的序列长度分别为 $T$ 和 $S$ 时,将它们打包到矩阵 $Q \in \mathbb{R}^{T \times D_m}$ 和 $K \in \mathbb{R}^{S \times D_m}$ 中,并将它们的注意力张量 $A \in \mathbb{R}^{H \times T \times S}$ 转换为具有相同形状的新张量。

下面描述 Compose 函数的计算过程。$A_{i,j}$ 依次通过五个分支,并汇聚到一起:

  1. 第一分支:$A_{i,j}$ 首先通过一个加权矩阵 $\widetilde{W}_b$,其计算独立于 $Q_i$ 或 $K_j$。
  2. 第二分支:$A_{i,j}$ 首先通过 $w_{q1} \in \mathbb{R}^{H \times R}$ 被投影到一个较低的维度 $R$,然后通过 $w_{q2} \in \mathbb{R}^{R \times H}$ 恢复到原来的维度 $H$。
  3. 动态权重 $w_{q1}$ 和 $w_{q2}$ 是通过 $Q_i$ 计算的。这样,模型可以对头部间的信息共享建模。
    通过设置 $R \ll H$(本文中 $R = 2$),我们假设尽管头部之间可以以多种方式共享,但对于任意一对查询和键,仅需少量的共享模式即可满足。这种情况下,在第三个分支中,$A_{i,j}$ 按元素乘以一个由 $Q_i$ 计算出的门控权重 $w_{qg} \in \mathbb{R}^H$。该分支控制了在给定查询的情况下,各头部保留或遗忘其原始分数的程度。

计算动态投影权重 $w_{q1}$ 和 $w_{q2}$,从 $Q_i$ 出发,我们使用一个具有单隐藏层和 GELU 激活函数的前馈网络(FFN),其参数为 $W_{q1} \in \mathbb{R}^{D_m \times I}$ 和 $W_{q2} \in \mathbb{R}^{I \times R}$,其中 $I = 2HR$。我们在头部数量维度上对 $w_{q1}$ 应用了 RMSNorm(未缩放),以稳定训练,然后将其与 $A_{i,j}$ 相乘。具体计算如下:

$$
w_{q1}, w_{q2} = \text{Chunk}(\text{GELU}(Q_i W_{q1}) W_{q2}, \text{dim} = 1)
$$

$$
w_{q1} = \text{Rmsnorm}(\text{Reshape}(w_{q1}, (H, R)), \text{dim} = 0) \tag{5}
$$

$$
w_{q2} = \text{Reshape}(w_{q2}, (R, H))
$$

计算动态门控权重 $w_{qg}$,从 $Q_i$ 出发,我们通过一个线性投影(参数为 $W_{qg} \in \mathbb{R}^{D_m \times H}$),并加上一个 $\text{tanh}$ 非线性函数,计算门控权重 $w_{qg}$:

$$
w_{qg} = \text{tanh}(Q_i W_{qg}) \tag{6}
$$

对于 $K_j$,还有两个对称的分支,其计算过程与 $Q_i$ 相同。五个分支的输出被汇总以得到最终的更新向量:

$$
A’{i,j} = A{i,j} W_b + A_{i,j} w_{q1} w_{q2} + A_{i,j} w_{qg} + A_{i,j} w_{k1} w_{k2} + A_{i,j} w_{kg} \tag{7}
$$

DCMHA 的可训练参数为:

$$
\theta = {W_b, W_{q1}, W_{q2}, W_{qg}, W_{k1}, W_{k2}, W_{kg}}
$$

这些参数与模型的其他参数一起进行端到端的学习。

4.1 张量分解视角

为了实现动态头部组合,我们需要 $T \times S$ 的变换矩阵(即组合映射),其形状为 $H \times H$,用于每一对 $Q_i$ 和 $K_j$。换句话说,我们需要计算一个与输入相关的 4 维变换张量 $W \in \mathbb{R}^{T \times S \times H \times H}$ 并将其应用于 3 维注意力张量 $A \in \mathbb{R}^{H \times T \times S}$。

尽管理论上有多种方法可以做到这一点,但不同的方法可能在效率上有所不同。上述 Compose 的计算等价于将 $W$ 分解为两层,以优化参数和计算效率:

$$
A’{i,j} = A{i,j} W_{ij}, \quad i \in [1, T], \ j \in [1, S]
$$

$$
W = W_b + \text{ED}(\mathcal{W}_q, \text{dim}) + \text{ED}(\mathcal{W}_k, \text{dim})
$$

这里:

  • $W_b \in \mathbb{R}^{H \times H}$ 是独立于输入的静态权重。
  • $\mathcal{W}_q \in \mathbb{R}^{T \times H \times R}$ 和 $\mathcal{W}_k \in \mathbb{R}^{S \times H \times R}$ 是动态矩阵,分别在行(row-wise)和列(column-wise)上变化。

公式 (8) 展示了公式 (7) 的“批量版”:

$$
A’{i,j} = A{i,j} W_b + A_{i,j} \mathcal{W}q w{q2} + A_{i,j} w_{qg} + A_{i,j} \mathcal{W}k w{k2} + A_{i,j} w_{kg} \tag{8}
$$

变换张量 $W$ 被分解为以下几部分的总和:

  1. 一个二维张量 $W_b \in \mathbb{R}^{H \times H}$,它是输入无关的静态权重。
  2. 两个三维张量 $\mathcal{W}_q$ 和 $\mathcal{W}_k$,它们分别与查询 $Q_i$ 和键 $K_j$ 的行和列相关联。
  • $\text{ED}(\cdot, \text{dim})$ 表示沿维度的扩展(ExpandDims)操作。
  • 动态矩阵 $\mathcal{W}_q$ 和 $\mathcal{W}_k$ 分别在行(row-wise)和列(column-wise)上变化,用于捕捉输入依赖的特性。

我们称这种分解为 行加列组合分解(row-plus-column composition)
这样,3 维张量被分解为两个低秩张量的总和(一个针对行,另一个针对列),从而降低了计算复杂度。

  • $\mathcal{W}{q/k1} \in \mathbb{R}^{T/S \times H \times R}$ 和 $\mathcal{W}{q/k2} \in \mathbb{R}^{T/S \times R \times H}$ 表示动态张量。
  • 一个对角张量则由二维张量 $\mathcal{W}_{q/k g} \in \mathbb{R}^{T/S \times H}$ 填充。

这种分解形式在其他地方也有使用(例如 Zhao et al., 2016; Gu et al., 2021a),被称为 低秩加对角分解(low-rank plus diagonal decomposition)。
分解相关的特点如下:

  • 行加列分解:低秩加对角分解允许分别对注意力张量 $A$ 应用 $\mathcal{W}_q$ 和 $\mathcal{W}_k$,而无需显式构造大的 4 维张量 $W$。
  • 尺寸优化:该分解减少了作用于注意力向量上的变换矩阵的大小,从 $H^2$ 降至 $2HR + H$。
  • 高效计算:得益于该分解,结果张量的大小远小于 $W$,因此可以从输入的查询和键高效计算。例如:

$$
\mathcal{W}{qg} = \text{tanh}(Q W{qg}) \tag{9}
$$

这是公式 (6) 的批量版本。

4.2 张量并行训练的分组组合

在威震天式张量并行(TP)训练中,将一层的注意头分成若干组,并将每组注意头放置在一个节点上,在该节点上执行这些注意头的计算。在典型的TP设置中(例如H = 32, TP = 4),可以在每组8个头像中组成头像,而不是在所有32个头像中组成头像(每组头像的动态投影等级R可以设置为较小的值1)。由于组成在每个节点内局部进行,因此头像之间没有跨组交互,因此没有DCMHA引入额外的跨节点通信。这种分组组合可以通过对Compose的简单修改来实现。我们实现了分组DCMHA,并通过TP培训对其进行了测试。通过实证研究,我们发现分组构图和全头部构图在性能和速度上差异不大。

4.3 复杂性分析

额外参数的比例,即 DCMHA 的 $\theta_{\text{pre}}$ 和 $\theta_{\text{post}}$ 的参数量与整个模型的参数量之比,与头部维度 $D_h$ 成反比,并且在常用的 $D_h$ 值(例如 $128$)下可以忽略不计。
额外计算的比例与 $D_h$ 成正比,随着 $\rho = S / D_m$ 的增加而增加,其中 $S$ 是序列长度。同时,对于足够大的模型(例如参数量 $\geq 6.9B$,$D_m \geq 4096$)以及典型的 $S$ 值(例如 $2048 \leq S \leq 8192$),额外计算的比例仍然非常小。

5.挑战与潜在困难

计算开销:尽管DCMHA在性能上有所提升,但引入的额外操作导致了训练和推理过程中的计算开销增加。作者通过实验表明,这种开销随着模型规模的增大而减少,但在较小规模模型中仍然较为明显。
适应现有模型:由于DCMHA与传统MHA在头投影统计上有较大差异,难以通过微调将预训练的Transformer模型直接转换为DCFormer。作者通过集成梯度归因分析发现,低层头组合对模型性能的影响更大,但预训练模型在微调时低层的更新相对较小,限制了性能提升。
并行训练:在大规模并行训练中,DCMHA需要在每个节点内进行局部组合,避免跨节点通信,这增加了实现的复杂性。