Page 192 - 《软件学报》2021年第12期
P. 192
3856 Journal of Software 软件学报 Vol.32, No.12, December 2021
矩阵.对推荐标签数量较少或点击率预测任务(二分类),使用联邦蒸馏方法可大大减少上传参数的体
量,缓解大规模设备下可能造成的通信拥堵;
(3) 联邦中心使用联邦学习算法将接收到的每台设备上传的标签平均 Logits 向量整合为新的全局 Logits
向量.具体地,针对每台设备,联邦中心将其他设备发送的 Logits 向量使用联邦学习算法构建出该台设
备的教师模型,并将教师模型分发到每台设备中(该步骤具体流程详见表 3);
(4) 设备接收教师模型,通过结合自适应学习率策略(见第 2.5 节)优化联合损失函数(见第 2.3 节),并以此
指导学生网络的训练.联合损失函数包含教师网络、学生网络的损失,同时还包含教师网络与学生网
络之间的差异度.该步骤算法流程详见表 2.
以上描述中,步骤(1)和步骤(3)中的推荐算法和联邦学习算法不限,可根据实际需求自由组合.在下面的章
节,我们将详细描述图 1 流程及表 2、表 3 算法中使用的策略.
Table 2 Attentional federated distillation—Processes on devices
表 2 注意力联邦蒸馏算法——设备流程
输入:设备集K,标签集T,设备模型迭代训练轮数E;
k
1 随机初始化每台设备每个标签的Logits S
t
2 for k in K do
/ k
/ k
3 接收联邦中心分发的教师模型Logits S / k = {S 1 ,...,S | | }
T
4 for e←1 to E do
k
5 e l ← Eq.(1) //计算学生模型损失
6 GL(e,k)←Eq.(3,4) //计算联合损失
7 η←Eq.(20) //自适应学习率策略选择合适学习率
8 e ω k e ω ← k η ∇ − GL (, )e k
9 for t in T do //计算设备k每个标签的Logits
k
10 t S ← Eq.(5)
11 end
12 end
13 for t in T do //计算设备标签的平均Logits
k
14 统计本设备标签为t的训练数据量 n
t
k
k
15 t S ← t S k / n
t
16 end
k
k
k
17 发送设备k的平均Logits S = {S 1 ,...,S | | } 到联邦中心
T
18 end
Table 3 Attentional federated distillation—Processes on the federated center
表 3 联邦注意力蒸馏算法——联邦中心流程
输入:设备集K,标签集T,全局模型最大迭代轮数MaxEpoch;
/k
1 随机初始化每台设备除该设备外的Logits S
2 for epoch←1 to MaxEpoch do
3 for k in K do
4 接收设备k的平均Logits S = {S 1 ,...,S | | }
k
k
k
T
5 for t in T do
6 t S ← k t S + k t S / k //累加其他设备标签t的Logits
7 end
8 end
9 //计算教师网络Logits并分发
10 for k in K do
11 for t in T do
12 t S / k ← Eq.(6)
13 end
14 分发教师模型Logits S / k = {S 1 ,...,S | | } 到设备k
/ k
/ k
T
15 end
16 end