学习私有神经语言模型与专注聚合
1.摘要
移动键盘建议通常被视为单词级别的语言建模问题。集中式机器学习技术需要收集大量用户数据进行训练,这可能引发与用户敏感数据相关的隐私问题。联邦学习(FL)提供了一种有希望的方法,通过在分布式客户端上训练模型而不是在中央服务器上进行训练,来实现对智能个性化键盘建议的私有语言建模。为了获得用于预测的全局模型,现有的FL算法简单地对客户端模型求平均,并且在模型聚合过程中忽略了每个客户端的重要性。此外,在中央服务器上优化学习一个泛化良好的全局模型没有得到优化。为了解决这些问题,我们提出了一种新颖的聚合方法,该方法考虑了客户端模型对全局模型的贡献,并结合了在服务器聚合期间的优化技术。我们提出的注意聚合方法通过迭代更新参数并关注服务器模型与客户端模型之间的距离来最小化服务器模型与客户端模型之间的加权距离。在两个流行的语言建模数据集和一个社交媒体数据集上的实验表明,在大多数比较设置下,我们的提出的算法在困惑度和通信成本方面均优于其同类算法。
2.现有方法的局限性与不足
- 简单平均法的不足:现有的联邦学习算法在聚合客户端模型时仅采用简单的平均方法,忽略了不同客户端对全局模型的重要性。
案例:在某些情况下,某些客户端的模型可能比其他客户端更优秀,但简单平均法无法识别并充分利用这些高质量的模型。 - 缺乏优化机制:在中央服务器上的模型聚合过程中没有针对学习一个泛化性能良好的全局模型进行优化。
案例:这可能导致最终的全局模型不能很好地适应不同用户的个性化需求,影响预测效果。
3.理论
- 联邦学习
定义:联邦学习是一种分布式机器学习方法,它允许模型在多个分散的客户端(例如移动设备)上进行训练,而无需将数据集中到一个中心服务器,从而保护用户隐私并减少数据传输成本。
在论文中的应用:本文利用联邦学习来解决移动键盘建议这一语言建模问题,在不收集用户数据的情况下通过聚合来自不同用户的本地模型来训练全局模型。 - 注意力机制
定义:注意力机制是一种模拟人类视觉注意力的方法,它使模型能够专注于输入数据中最相关部分,从而提高处理效率和准确性。在神经网络中,它通常用于加权输入特征或层间连接。
在论文中的应用:文中引入了基于参数的层间注意力机制,用于量化每个客户端模型对全局模型的重要性,并通过优化这些权重来改进联邦学习中的模型聚合过程。
4.算法步骤
4.1 客户端模型更新
- 随机选择客户端:从所有存在的客户端中随机选择一部分参与当前轮次的训练。
- 本地训练:选中的客户端被恢复到全局模型,并在其本地数据上执行若干次迭代更新(如使用梯度下降法),然后将更新后的参数返回给服务器。
4.2 服务器端优化
- 初始化参数模型:设置初始模型参数。
- 计算注意力分数:对于每个客户端模型的每一层参数,计算其与服务器模型对应层参数之间的差异,并通过softmax函数转换为注意力分数
$$
\alpha_{k l} = \text{softmax} \left( s_{i l} \right) = \frac{\exp \left( s_{i l} \right)}{\sum_{k} \exp \left( s_{k l} \right)}
$$
- 更新全局模型:利用加权平均的方式更新全局模型参数,其中权重由上述计算得到的注意力分数决定。
$$
\theta_{t+1}^{l} - \theta_{t}^{l} = \sum_{k} \alpha_{k l} \left( \theta_{k t+1}^{l} - \theta_{t}^{l} \right)
$$
4.3 全局模型参数更新公式
$$
\theta_{t+1} \leftarrow \theta_{t} + \sum_{k=1}^{m} \alpha_{k} \left( \theta_{k t+1} - \theta_{t} \right) + \beta \mathcal{N}(0, \sigma^{2})
$$
注:此公式用于更新全局模型参数。
- $$\theta_{t}$$ 为可能在 $$t+1$$ 轮的全局模型参数,
- $$\theta_{t}$$ 为可能在 $$t$$ 轮的全局模型参数,
- $$\theta_{t}$$ 为服务器优化的学习率(步骤上),
- $$m$ 为场景合的客户端数,
- $$\alpha_{k}$$ 为由公式(l) 计算得到的注意力权重,衡量其重要性,
- $$\theta_{k t+1}$$ 为第k个客户端在round后第t+1轮的本地模型参数,
- $$\beta$$ 为控制噪声范围超参数,控制随机扰动的范围,
- $$N(0, \sigma^{2})$$ 为均值为0,方差为标准正太噪声。
公式通过考虑客户端和伊的噪音扰动,以及添加噪声保护隐私,从而更新到全局模型参数,使其更接近于客户端模型并在参数空间中更好地表示联邦客户端。
5.代码框架
fedatt.py
1 | import fedatt_algorithm |
fedatt_server.py
1 | from plato.servers import fedavg |
fedatt_algorithm.py
1 | from collections import OrderedDict |
对于该算法的相关代码如下:
- 定义 Algorithm 类,继承自 fedavg.Algorithm。
- 在构造函数 init 中初始化了 self.model_weights。
- 定义 aggregate_weights 方法,用于实现带注意力机制的聚合策略:
- 计算客户端权重与基准权重之间的差异。
- 初始化 att_update 为所有权重更新的零值。
- 计算每层权重的注意力分数 atts,这些分数是通过计算每个差异的L2范数并对其执行 softmax 得到的。
- 基于注意力分数通过加权组合来更新权重。
- 添加噪声机制,包括步长 epsilon 和噪声量级 magnitude。
- 返回更新后的权重 att_update。