Page 118 - 《振动工程学报》2025年第8期
P. 118
1758 振 动 工 程 学 报 第 38 卷
的向量场,并最终转变为输出变量。在这个过程中,
依靠在常微分方程求解器中选择步长,可以控制每 2 神经常微分方程故障诊断模型的建立
一层函数计算的数量。图 1 中五个箭头相当于具有
五个隐藏层的网络。标记为 x t 的虚线路径是变量通 由于神经常微分方程可以被视为连续化的残差
为 NODE 的最终输出。由于 网络,为了更好地说明其在故障诊断方面的优势,分
过 NODE 的轨迹 , x t N
神经网络的参数 θ 不随变量在向量场中的传递而改 别建立残差网络和神经常微分方程网络两个模型进
变,为了简便,网络将使用 f ( x t,t )取代 f ( x t,t,θ )。 行对比。残差网络(ResNet)模型共有三层结构,分
以上为 NODE 的前向传递过程。为了在反向 别为下采样卷积层、残差网络层和前馈层。如图 2
传递中节省内存成本,并且避免引入额外的数值计 所示,其下采样层用来接收输入数据,进而通过设置
算误差,采用伴随灵敏度方法,舍弃传统的反向传递 残差块(ResBlock)中第一个 Conv2D 层的步幅来提
算法。其核心思想是采用伴随灵敏度方法取代传统 取特征并降低数据的维数。下采样层的输出作为第
的梯度传递,而不再通过前向传递中的常微分方程 一个 ResBlock 的输入,之后依次传递给下一个 Res‑
求解器。具体过程如下,若将前向计算的损失函数 Block。像这样连接 n 个 ResBlock,最终 ResBlock 的
记为 L,则整个常微分方程求解器的前向损失为: 输出被展平并通过前馈网络,该网络将输出七种分
( t 1 ) 类标签。表 1 为所用的 ResNet 架构的基本布局以
L( x t 1) = L x t 0 ∫ f ( x t,t,θ ) dt (4)
+
t 0 及每一层输出张量的大小。表 2 指出架构中使用的
其伴随状态的定义为: ResBlock 的构成。ResNet 模型的超参数选择如表 3
∂L 所示。
α(t) = (5)
∂x t
式中, α( t )表示在变量 t 下的关于隐藏状态 x 的伴随
状态。
在 NODE 反向传递的过程中使用伴随灵敏度
法, α( t ) 为损失 L 对 x t 的导数。可以在连续时间下
通过下式将梯度继续向前传递:
dα( t ) ∂f ( x t,t,θ ) (6) 图 2 残差网络模型框架图
dt =-α( t ) ∂x Fig. 2 ResNet model framework diagram
t 0 ∂f ( x t,t,θ )
∫ dt (7) 表 1 残差网络的模型架构及每层输出张量
α( t 0 )= α( t 1 )- α( t )
∂x t
Tab. 1 Model architecture and output tensor per layer
t 1
将上式从 t 1 伴随状态到 t 0 伴随状态的计算过程 of ResNet
拓展至整个反向传递网络,计算此过程的整个梯度:
网络层 层架构 输出
dL t 0 ∂f ( x t,t,θ ) dt Conv2D (64,64,1020,1)
dθ = ∫ α( t ) ∂θ (8)
下采样层 ResBlock (64,64,510,1)
t 1
以上为 NODE 使用伴随灵敏度法完成反向传
ResBlock (64,64,255,1)
递的过程,由于此方法不需要保存反向传递过程中 残差层 ResBlock×6 (64,64,255,1)
产生的导数梯度,只需记录最后一个时间步的隐藏 GroupNorm (64,64,255,1)
状态,而神经网络需保存导数梯度、权重和偏置等参 ReLU (64,64,255,1)
数来完成导数链式法则的传递,故减少了不必要的 前馈层 Adaptive AvgPool2D (64,64,1,1)
内存消耗。该方法之所以可以替代神经网络的反向 Flatten (64,64)
传递,是在前向传递中使用了神经常微分方程,计算 Linear (64, 7)
出了每层的微分形式表示的隐藏状态。而且选择不 表 2 残差块的构成
同的常微分方程求解方法,更有利于平衡计算时间 Tab. 2 Composition of ResBlock
和准确率。
层架构 输出
基于此,许多神经网络可以被微分方程所替换, GroupNorm (64,64,255,1)
利用常微分方程求解器取代各神经元及各层之间复 ReLU (64,64,255,1)
杂的计算过程,避免盲目堆叠层数导致的浪费,可以 Conv2D (64,64,255,1)
有效减少参数数量,实现参数共享,显著减少内存 GroupNorm (64,64,255,1)
成本 [16] 。 Conv2D (64,64,255,1)

