Page 13 - 《软件学报》2024年第4期
P. 13
刘鑫 等: 基于多样真实任务生成的鲁棒小样本分类方法 1591
K-shot 分类任务上进行测试, 认为: 如果一个方法可在 N-way K-shot 分类任务上取得好的结果, 该方法就具备
小样本学习能力. 一个典型的 N-way K-shot 分类任务中包括两部分数据: 支持集(support set)和查询集(query
set). 支持集对应于传统机器学习中的训练样本, 包含少量的有标签训练数据, 包括 N×K 个样本, 其中, N 为需
要分类的类别数目, K 为每个类中的样本数目, 通常设为 1 或 5. 查询集对应于传统机器学习中的测试样本, 为
待分类的无标签数据, 用于验证方法通过支持集样本对 N 个类的认知能力, 包括 N×Q 个样本.查询集的类别与
支持集的类别一致, 只是每个类的样本数目为 Q 个. N-way K-shot 分类任务就是通过在 N×K 个支持集样本上
学习一个模型, 用学到的模型对 N×Q 个查询样本进行分类.
但是直接利用 N×K 个样本进行学习, 十分容易出现过拟合现象. 为了解决这一问题, 目前的小样本分类
方法通常借助于一个有大量有标签数据的辅助数据集. 先在该数据集上学习一些知识, 利用这些知识帮助目
标小样本任务进行学习. 在小样本学习中, 把辅助数据集中的类别叫做基类(base class), 测试小样本任务中的
类别叫做新类(novel class). 通常, 基类和新类的类别是不相交的. 基于元学习的小样本分类方法希望通过学
着去学习的方式来进行学习, 通过在辅助数据集上构造一系列不同的 N-way K-shot 分类任务, 学习解决这类
任务的元知识, 利用元知识帮助目标任务进行学习, 如图 2 所示. 一般把这样的训练方式叫做插曲式训练机制
或者元学习训练机制. 测试时, 为了验证模型可以适用于不同的小样本分类任务, 通常在测试数据集上构造
一系列 N-way K-shot 任务进行测试, 计算在这些任务上的平均性能作为模型的评价指标.
图 2 元学习训练和测试机制示意图
2.2 MAML和原型网络
在基于元学习的小样本分类方法中, 通常我们会构造一系列元训练任务 { } M 1 , 每个任务 i 中包含着两
ii=
×
}
xy
}
×
xy
j
j
部分数据: 支持集 i S = {, i j NK 和查询集 i Q = {, i j NQ . 模型 f f 在数据集上的损失记为A(f f ;).
j=
i
j=
1
1
i
[6]
*
MAML 的目标是: 学习一个适用于不同任务的好的初始化参数θ , 在新的任务上, 少量支持集样本 S
只需要更新一步或几步就可以达到当前任务的最优点. MAML 可以形式化为如下双层优化问题:
θ * argmin ; ( ),θ =
1 θ (1)
( lg( f
; ( )θ = ∑ M A � f MAML ; i S ); i Q )
M i= 1
S
lg( f
lg( f
其中, � f MAML ; i S ) 表示利用 对任务 i 进行优化更新. 以一步梯度下降更新为例, � f MAML ; i S ) 可写成
= −
lg( f
[ ( f
� f MAML ; i S ) θ η∇ A θ MAML ; S )] , η是学习率, M 是元训练任务的数量.
θ
[7]
原型网络 要学习的元知识是一个低维度量空间, 在该空间中查询样本可以根据计算与每个类原型的距
离进行分类. 给定每个类的少量支持集样本, 第 r 的类原型 c r 为该类所有支持集样本在度量空间特征表示
f θ PN ()x s 的均值. 原型网络的损失函数如下:
c
(
−
1 M 1 M exp( df PN (x q ), ))
= ∑ −A log ( p y q = | rx q , i k ∑ ) = ∑ 1 ∑ − log θ , i k r (2)
−
M i= 1 i ,, k r , i k M i= i ,, k r ∑ r′ exp( df θ PN (x q , i k ),c r′ ))
(
其中, f θ PN (x q , i k ) 是第 i 个任务中第 k 类的查询样本经过原型网络得到的隐表示, M 是元训练任务的数量.