Page 192 - 《软件学报》2021年第12期
P. 192

3856                                Journal of Software  软件学报 Vol.32, No.12, December 2021

                 矩阵.对推荐标签数量较少或点击率预测任务(二分类),使用联邦蒸馏方法可大大减少上传参数的体
                 量,缓解大规模设备下可能造成的通信拥堵;
             (3)  联邦中心使用联邦学习算法将接收到的每台设备上传的标签平均 Logits 向量整合为新的全局 Logits
                 向量.具体地,针对每台设备,联邦中心将其他设备发送的 Logits 向量使用联邦学习算法构建出该台设
                 备的教师模型,并将教师模型分发到每台设备中(该步骤具体流程详见表 3);
             (4)  设备接收教师模型,通过结合自适应学习率策略(见第 2.5 节)优化联合损失函数(见第 2.3 节),并以此
                 指导学生网络的训练.联合损失函数包含教师网络、学生网络的损失,同时还包含教师网络与学生网
                 络之间的差异度.该步骤算法流程详见表 2.
             以上描述中,步骤(1)和步骤(3)中的推荐算法和联邦学习算法不限,可根据实际需求自由组合.在下面的章
         节,我们将详细描述图 1 流程及表 2、表 3 算法中使用的策略.
                            Table 2  Attentional federated distillation—Processes on devices
                                    表 2   注意力联邦蒸馏算法——设备流程
                         输入:设备集K,标签集T,设备模型迭代训练轮数E;
                                                         k
                         1   随机初始化每台设备每个标签的Logits  S
                                                         t
                         2   for k in K do
                                                               / k
                                                                   / k
                         3      接收联邦中心分发的教师模型Logits  S     / k  = {S 1 ,...,S | | }
                                                                   T
                         4        for e←1 to E do
                                    k
                         5          e l ← Eq.(1)    //计算学生模型损失
                         6          GL(e,k)←Eq.(3,4)   //计算联合损失
                         7         η←Eq.(20)   //自适应学习率策略选择合适学习率
                         8           e ω  k  e ω ←  k  η  ∇ −  GL (, )e k
                         9         for t in T do   //计算设备k每个标签的Logits
                                       k
                         10            t S ← Eq.(5)
                         11        end
                         12     end
                         13     for t in T do   //计算设备标签的平均Logits
                                                           k
                         14        统计本设备标签为t的训练数据量 n
                                                           t
                                           k
                                     k
                         15          t S ←  t S k  / n
                                           t
                         16     end
                                                          k
                                                  k
                                                      k
                         17        发送设备k的平均Logits  S = {S 1 ,...,S | | } 到联邦中心
                                                          T
                         18   end
                        Table 3  Attentional federated distillation—Processes on the federated center
                                   表 3   联邦注意力蒸馏算法——联邦中心流程
                     输入:设备集K,标签集T,全局模型最大迭代轮数MaxEpoch;
                                                     /k
                     1   随机初始化每台设备除该设备外的Logits S
                     2   for epoch←1 to MaxEpoch do
                     3      for k in K do
                     4         接收设备k的平均Logits  S = {S 1 ,...,S | | }
                                                    k
                                                 k
                                                        k
                                                        T
                     5         for t in T do
                     6            t S ←  k  t S +  k  t S  / k     //累加其他设备标签t的Logits
                     7         end
                     8      end
                     9      //计算教师网络Logits并分发
                     10     for k in K do
                     11        for t in T do
                     12           t S  / k  ←  Eq.(6)
                     13        end
                     14        分发教师模型Logits  S  / k  = {S 1 ,...,S | | } 到设备k
                                                       / k
                                                  / k
                                                       T
                     15     end
                     16   end
   187   188   189   190   191   192   193   194   195   196   197