深度学习在表格数据中的挑战
深度学习已在计算机视觉、自然语言处理等领域引发革命,但表格数据领域仍由经典机器学习算法(如梯度提升)主导。直觉上,神经网络作为通用近似器,理论上应能处理表格数据,但实际效果不及梯度提升树。这可能与决策树的归纳偏置更适合表格数据有关。
可微分决策树的突破
2015年,Kontschieder等人提出深度神经决策森林,通过将决策节点的严格二元路由松弛为概率化(使用Sigmoid函数),实现了决策树的可微分性。具体而言:
- 叶节点:替换为Softmax层,输出类别分布。
-
- 决策节点:使用Sigmoid函数计算样本向左/右路由的概率,通过路径概率乘积得到叶节点到达概率,最终预测为所有叶节点的加权平均。
神经遗忘决策树(NODE)
NODE基于对称生长的遗忘树(Oblivious Tree),每层使用相同特征进行分裂。其核心创新包括:
- 特征选择:采用
α-entmax
替换Softmax,实现稀疏特征选择(学习矩阵F
)。 -
- 阈值松弛:将不可微的Heaviside函数替换为可微的
双面α-entmax
,并引入可学习的尺度参数b
。
- 阈值松弛:将不可微的Heaviside函数替换为可微的
-
- 响应张量:通过外积生成路径选择权重,最终输出为响应张量的加权和。
深度NODE架构
通过堆叠多个NODE层(带残差连接)构建深度模型:
- 每层输入为前一层的输出与原始特征的拼接。
-
- 最终预测为各层输出的平均。
实验与结果
在Epsilon、Higgs等6个数据集上,NODE与CatBoost、XGBoost和全连接神经网络对比:
- 默认参数:NODE(单层2048棵树,深度6)表现优于传统方法。
-
- 调参后:NODE在多数任务中保持领先。
实现与工具
- 官方实现:基于PyTorch的模块化代码库。
-
- 集成库:支持在PyTorch Tabular中直接调用NODE及其他表格数据算法。
参考文献
- Kontschieder et al., Deep Neural Decision Forests (ICCV 2015).
-
- Peters et al., Sparse Sequence-to-Sequence Models (ACL 2019).
-
- Popov et al., Neural Oblivious Decision Ensembles (arXiv:1909.06312).
- 更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)