深度学习在表格数据中的挑战

深度学习已在计算机视觉、自然语言处理等领域引发革命,但表格数据领域仍由经典机器学习算法(如梯度提升)主导。直觉上,神经网络作为通用近似器,理论上应能处理表格数据,但实际效果不及梯度提升树。这可能与决策树的归纳偏置更适合表格数据有关。

可微分决策树的突破

2015年,Kontschieder等人提出深度神经决策森林,通过将决策节点的严格二元路由松弛为概率化(使用Sigmoid函数),实现了决策树的可微分性。具体而言:

  • 叶节点:替换为Softmax层,输出类别分布。
    • 决策节点:使用Sigmoid函数计算样本向左/右路由的概率,通过路径概率乘积得到叶节点到达概率,最终预测为所有叶节点的加权平均。

神经遗忘决策树(NODE)

NODE基于对称生长的遗忘树(Oblivious Tree),每层使用相同特征进行分裂。其核心创新包括:

  1. 特征选择:采用α-entmax替换Softmax,实现稀疏特征选择(学习矩阵F)。
    1. 阈值松弛:将不可微的Heaviside函数替换为可微的双面α-entmax,并引入可学习的尺度参数b
    1. 响应张量:通过外积生成路径选择权重,最终输出为响应张量的加权和。

深度NODE架构

通过堆叠多个NODE层(带残差连接)构建深度模型:

  • 每层输入为前一层的输出与原始特征的拼接。
    • 最终预测为各层输出的平均。

实验与结果

在Epsilon、Higgs等6个数据集上,NODE与CatBoost、XGBoost和全连接神经网络对比:

  • 默认参数:NODE(单层2048棵树,深度6)表现优于传统方法。
    • 调参后:NODE在多数任务中保持领先。

实现与工具

  • 官方实现:基于PyTorch的模块化代码库。
    • 集成库:支持在PyTorch Tabular中直接调用NODE及其他表格数据算法。

参考文献

  1. Kontschieder et al., Deep Neural Decision Forests (ICCV 2015).
    1. Peters et al., Sparse Sequence-to-Sequence Models (ACL 2019).
    1. Popov et al., Neural Oblivious Decision Ensembles (arXiv:1909.06312).
  2. 更多精彩内容 请关注我的个人公众号 公众号(办公AI智能小助手)