简介
随着机器学习与深度学习的发展,DNN
在图像、语音和文本类型的数据上都有了广泛的应用,其优势如下:
- 类似图像、文本,通过对其数据进行编码,从而得到一个表征表格数据。
- 减少对特征工程的依赖(参加过
kaggle
竞赛的同学都知道特征在模型中的重要性)。 - 可通过
online learning
的方式更新模型。
而在表格类数据的任务中,大部分都是由决策树模型成为标配,其优势如下:
- 可以根据决策树回溯其推理过程,可解释性较强。
- 决策树的流形可以看作是超平面的边界,对于表格类数据的效果很好。
- 训练速度很快。
因此如果有一种模型即吸收决策树模型的可解释性和稀疏特征选择优点,也具有 DNN
的 end-to-end
长处,那毫无疑问就是针对表格数据的利器 - TabNet
。
整体架构
TabNet
整体架构如下图所示:
- 特征首先通过
BatchNorm
层,才能作为其他阶段的输入。 - 网络中存在重复的结构
Step
,各个Step
的输入均是BN
层的输出。 Step
中包含两个最重要的组件Feature transformer
与Attentive trandformer
,该两个组件均有BN
层、FC
层组成,其内部的堆叠个数可作为超参数进行配置。Feature transformer
特征处理,其输出经过一个分裂Split
部分,一部分经由ReLU
之后得到输出;另一部分作为Attentive transformer
的输入。Attentive transformer
权重系数特征选择,其输出得到Mask
稀疏掩码矩阵。- 各个
Step
输出求和后,同时经过FC
层变换映射得到模型的最终输出。 Mask
稀疏掩码矩阵输出累加之后构成Feature attributes
,其用于模型的可解释性。
具体模块
特征处理
特征处理组件 Feature transformer
其被设计成串行的方式,网络结构如下图所示:
Shared across decision steps
各个Step
共享层,即在各个Step
中都是被共享的。Decision step dependent
各个Step
独立层,即在每层需要单独训练。
通过图中的网络结构可以将 FC
、BN
、GLU
看作一个 block
。
其中 GLU(Gated Linear Unit)
是广义线性单元,其完成的变换如下:$GLU(a,b)=a \otimes \sigma (b)$,即输入两个相同 shape
的矩阵,b
矩阵经 $\sigma$ 变换后,与 a
矩阵中对应的元素相乘。GLU
层的源码如下:
1 | class GLU_Layer(torch.nn.Module): |
各个 block
通过跳层连接 skip-connection
后,再乘以 $\sqrt{0.5}$,用于防止网络中输出的方差变换波动太大,影响学习的稳定性。
最后 Feature transformer
在完成计算后需要经过 Split
分裂成两部分,其分别是:
d[i]
:表示第i
个Step
上用于最终的模型输出。输出维度为 $R^{B*N_d}$,B
为BatchSize
,$N_d$ 为输出维度,经过ReLU
后,用于最终的模型输出。a[i]
:表示第i
个Step
上用于Attentive transformer
的输入。输出维度为 $R^{B*N_a}$,用于Attentive transformer
的输入。
特征选择
TabNet
的特征选择过程主要为每个 Step
学习一个 Mask
掩码矩阵,记为 $M[i] \in R^{B*D}$,其保持跟输入特征一样的 shape
。其网络结构如下:
Mask
的计算公式:
网络中存在一个加权缩放因子Prior scales
,记为P
,其shape
与输入特征一致。
当在i
时刻,缩放因子与上一时刻Feature transformer
的输出a[i-1]
的关系为:
$$ M[i] = sparemax(P[i-1] * h_i(a[i-1])) $$
上式中的sparemax
是归一化操作,类似于softmax
,不过其可以得到更稀疏的输出结果,即只选择最显著的特征,以此达到特征选择的目的。
$h_i(a[i-1]))$ 指图中的FC
、BN
层。缩放因子的计算公式:
$$ P[i] = _{j=1}^i (\gamma - M[j]) $$
由上述公式可知,对于第i
步中的缩放因子与j=0,1...,i
步内的Mask[j]
有关。
公式中的 $\gamma$ 在不同的取值情况下会有不同的效果:- 当 $\gamma > 1$ 时,由于
M[j]
内的取值在[0,1]
内且取值集中在0
或1
附近,特征在后续的Step
中可以再次被以更高的权重使用。 - 在
Step = 0
时刻 $P[0] = 1^{B * D}$,图中省略了该Mask
。
- 当 $\gamma > 1$ 时,由于
Mask
的稀疏性:
为了保证Mask
的稀疏性,需要对M[i]
中要学习的参数进行正则化约束。其约束的方式为:对于网络中所有Step
内涉及的Mask
的参数 $M_{b,j}$,以如下公式求和作为损失函数的一部分:
$$ L_{sparse} = sum_{i=i}^N sum_{b=1}^B sum_{j=1}^D {\frac {-M_{b,j}[i]} {N * B}} log(M_{b,j}[i] + \epsilon) $$
自监督学习
DNN
可以进行表征学习,而 TabNet
应用自监督学习,通过 encoder-decoder
框架来获得表格数据的 representation
,从而也有助于分类和回归任务。如下图所示:
简单来说即认为同一样本之间不同特征是具有关联的,因此自监督学习则是先人为 Mask
一部分特征,然后通过 encoder-decoder
模型来对 Mask
的特征进行预测。
自监督学习的 encoder
部分如上图所示,而 decoder
部分如下所示:
上面的 Encoded representation
就是 encoder
未经 FC
层的加和向量。
将其作为 decoder
的输入,decoder
同样利用 Feature transformer
,其将 representation
向量重构为 Feature
,之后经过若干个 Step
的加和,得到最后的重构 Feature
。
最终输出
TabNet
的最终输出先对各个 Step
中 Feature transformer
的输出进行 ReLU
变换,最终对所有的 Step
中的输出进行求和,经过全连接层,得到最终的输出。即:
$$
\begin{equation}
\begin{split}
d_{out} = sum_{i=1}^N ReLU(d[i]) \\\\
output = FC(d_{out}) = W_{final} d_{out}
\end{split}
\nonumber
\end{equation}
$$
可解释性
结合 TabNet
的网络结构以及中间的组件 Feature transformer
和 Attentive transformer
,从这些可以很明显的看出与 MLP
等神经网络的差异。MLP
神经网络中各层对输入特征均是无区别对待,而 TabNet
则是通过 Mask
矩阵来体现各个特征在不同 Step
中的重要性。
示例
1 | from pytorch_tabnet.tab_model import TabNetClassifier |
总结
多关注前沿机器学习的内容,会对工作有很大的启发!
引用
TabNet: Attentive Interpretable Tabular Learning
Github 源码
个人备注
此博客内容均为作者学习所做笔记,侵删!
若转作其他用途,请注明来源!