Page 193 - 《软件学报》2021年第12期
P. 193
谌明 等:一种基于注意力联邦蒸馏的推荐方法 3857
2.3 联邦蒸馏
现有的联邦学习算法是对模型权重进行平均,由于推荐系统中模型复杂,权重参数众多,分配到每台设备
上,模型参数回传到联邦中心,会占用大量的资源,并且联邦中心计算权重平均值也是一笔巨大的时间开销.当
采用现有联邦蒸馏算法的损失函数进行优化时,仅仅分别计算了教师网络和学生网络与真实标签的误差值,却
忽略了教师网络和学生网络本身的差异性给模型带来了影响,容易造成模型过拟合.通过实验发现,教师网络和
学生网络本身的差异性对模型的推荐效果具有较大的影响.为了减少学生网络和教师网络之间差异大造成的
影响,本文提出了一种新的目标函数.相比于传统目标函数只计算本地设备预测值与真实值之间的误差,本文提
出的目标函数除了利用其他设备作为教师模型来指导本地学生模型训练,还将学生模型与教师模型之间的差
别作为优化目标的一部分加入损失函数,降低设备间数据差异造成的影响.
首先,设备 k(k∈K)的本地学生模型及联邦中心分发的教师模型在该设备上的损失函数可分别定义为
k
k
k
l =f(p ,y ) (1)
k
k
L = ( fp teacher , y k ) (2)
k
k
k
其中,f(⋅)为损失函数,p 和 y 分别为设备 k 中学生模型对本地测试数据的预测值及其真实值, p teacher 为教师模型
对本设备测试数据的预测值.假设全局模型共需训练 E 轮,则训练 e 轮(e∈[1,E])后的联合损失函数由学生模型损
失、教师模型损失以及学生模型与教师模型差异组成(表 2 第 6 行),具体定义如下:
⎧ k λ k 2
⎪ l α e + ||ω || , if e = 1
⎪ 2
GL (, )e k = ⎨ k k λ k 2 (3)
−
β+
−
l α
L
( ||l
L
)KL
⎪ N e + β N e (1 α ) + 2 ||ω || , if e > 1
e
e
⎪学生模型损失 教师模型损失 学生模型与教师模型差异
⎩ 正则项
k
其中,α,β分别为学生模型和教师模型损失的权重参数,λ为正则项权重参数,ω 为设备 k 的模型参数(如神经网络
2
中的 Weights 和 Bias),||⋅|| 为 L2 范数.为节省参数通信量(传统联邦学习算法如 FedAvg 需传输模型参数)并增强
模型的泛化性能,本文方法在联合损失函数中增加了 L2 正则项(见公式(3)).由于高度偏斜的非独立同分布(non-
IID)数据会让学生模型之间的分布差异增大,降低整个模型的收敛效率,本文通过使用 KL 散度(Kullback–
Leibler divergence)来衡量学生模型和教师模型之间的差异,并将该差异作为全局损失函数的一部分进行优化.
差异计算方式如下:
|| K l k
(|| L = ∑
KL l k e k e ) l e k log e k (4)
k = 1 L e
公式(3)中,当 e=1 时(即第 1 轮全局模型训练),此时联邦中心尚未收集首轮本地设备的模型 Logits,本地设备
无需从联邦中心接受教师模型的 Logits,此时,联合损失仅包含本地学生模型的损失;当 e>1 时,联邦中心已完成
首轮模型收集并分发教师模型,则本地学生模型的优化可同时使用学生模型、教师模型及学生-教师模型差异
进行联合优化.同时,为加速模型收敛速度,本文提出一个可自动切换优化算法及选择合适学习率的优化策略,
用于优化联合损失函数(见第 2.5 节).优化后的本地学生模型对本地设备数据进行预测,得出新本地模型对应
ˆ k
每个数据标签的 Logits S ,并通过下式更新设备 k 对应标签 t 的 Logits(表 2 第 10 行):
t
S = k S + ˆ k S / k (5)
t t t
其中, S / k 为从联邦中心接受到的去除第 t 台设备后的 Logits 平均值(即教师模型).最后,设备 k 将平均后的本地
t
Logits 作为设备 k 的学生模型发送到联邦中心进行整合.联邦中心通过计算除去每台设备本身的其他设备学生
模型的 Logits 的平均值来得到该设备的教师模型,并分发给对应设备(表 3 第 16 行).具体计算方式如下:
k
S − S / k
S t / k = t t (6)
| K |
联邦蒸馏的过程减少了传统联邦学习过程中的模型权重回收和分发造成的时间和通信开销,能够有效提
升整体效率.同时,通过加入 KL 散度,将教师模型和学生模型之间的差异性加入到损失函数中进行优化,从而缓