Page 13 - 《软件学报》2024年第4期
P. 13

刘鑫  等:  基于多样真实任务生成的鲁棒小样本分类方法                                                     1591


         K-shot 分类任务上进行测试,  认为:  如果一个方法可在 N-way K-shot 分类任务上取得好的结果,  该方法就具备
         小样本学习能力.  一个典型的 N-way  K-shot 分类任务中包括两部分数据:  支持集(support set)和查询集(query
         set).  支持集对应于传统机器学习中的训练样本,  包含少量的有标签训练数据,  包括 N×K 个样本,  其中, N 为需
         要分类的类别数目, K 为每个类中的样本数目,  通常设为 1 或 5.  查询集对应于传统机器学习中的测试样本,  为
         待分类的无标签数据,  用于验证方法通过支持集样本对 N 个类的认知能力,  包括 N×Q 个样本.查询集的类别与
         支持集的类别一致,  只是每个类的样本数目为 Q 个.  N-way  K-shot 分类任务就是通过在 N×K 个支持集样本上
         学习一个模型,  用学到的模型对 N×Q 个查询样本进行分类.
             但是直接利用 N×K 个样本进行学习,  十分容易出现过拟合现象.  为了解决这一问题,  目前的小样本分类
         方法通常借助于一个有大量有标签数据的辅助数据集.  先在该数据集上学习一些知识,  利用这些知识帮助目
         标小样本任务进行学习.  在小样本学习中,  把辅助数据集中的类别叫做基类(base class),  测试小样本任务中的
         类别叫做新类(novel class).  通常,  基类和新类的类别是不相交的.  基于元学习的小样本分类方法希望通过学
         着去学习的方式来进行学习,  通过在辅助数据集上构造一系列不同的 N-way  K-shot 分类任务,  学习解决这类
         任务的元知识,  利用元知识帮助目标任务进行学习,  如图 2 所示.  一般把这样的训练方式叫做插曲式训练机制
         或者元学习训练机制.  测试时,  为了验证模型可以适用于不同的小样本分类任务,  通常在测试数据集上构造
         一系列 N-way K-shot 任务进行测试,  计算在这些任务上的平均性能作为模型的评价指标.













                                      图 2    元学习训练和测试机制示意图
         2.2   MAML和原型网络

             在基于元学习的小样本分类方法中,  通常我们会构造一系列元训练任务 { }                        M 1  ,  每个任务 i 中包含着两
                                                                        
                                                                         ii=
                                                       ×
                                                      }
                                                  xy
                                 }
                                   ×
                             xy
                              j
                                                   j
         部分数据:  支持集     i S  =  {,  i j NK  和查询集  i Q  = {,  i j NQ  .  模型 f f 在数据集上的损失记为A(f f ;).
                                                       j=
                              i
                                  j=
                                   1
                                                        1
                                                  i
                   [6]
                                                                  *
             MAML 的目标是:  学习一个适用于不同任务的好的初始化参数θ ,  在新的任务上,  少量支持集样本                              S
         只需要更新一步或几步就可以达到当前任务的最优点. MAML 可以形式化为如下双层优化问题:
                                            θ  *  argmin ; ( ),θ =  
                                           1      θ                                         (1)
                                                  ( lg( f
                                     ;  ( )θ =  ∑ M  A �  f MAML ; i S  ); i Q )
                                           M   i= 1               
                                    S
                                                                                lg( f
                lg( f
         其中,  �   f MAML ; i S ) 表示利用 对任务 i 进行优化更新.  以一步梯度下降更新为例,  �          f MAML ; i S  ) 可写成
                      = −
           lg( f
                             [ ( f
          �  f  MAML ; i S  ) θ η∇ A  θ  MAML ; S  )] , η是学习率, M 是元训练任务的数量.
                            θ
                    [7]
             原型网络 要学习的元知识是一个低维度量空间,  在该空间中查询样本可以根据计算与每个类原型的距
         离进行分类.  给定每个类的少量支持集样本,  第 r 的类原型 c r 为该类所有支持集样本在度量空间特征表示
          f θ PN ()x s  的均值.  原型网络的损失函数如下:
                                                                             c
                                                                     (
                                                                   −
                        1   M                    1   M       exp( df  PN  (x q  ), ))  
                      =   ∑    −A  log ( p y q  =  | rx q , i k  ∑  ) =  ∑  1 ∑    −  log  θ  , i k  r       (2)
                                                                     −
                        M   i= 1  i ,, k r  , i k    M  i=    i ,, k r  ∑  r′  exp( df θ PN  (x q , i k ),c r′  ))    
                                                                       (
         其中,  f θ PN  (x q , i k  ) 是第 i 个任务中第 k 类的查询样本经过原型网络得到的隐表示, M 是元训练任务的数量.
   8   9   10   11   12   13   14   15   16   17   18