3 个月前

Hopular:面向表格数据的现代霍普菲尔德网络

Hopular:面向表格数据的现代霍普菲尔德网络

摘要

尽管深度学习在视觉和自然语言处理等结构化数据任务中表现出色,但在表格数据(tabular data)上的表现却未能达到预期。对于表格数据,支持向量机(SVM)、随机森林(Random Forests)以及梯度提升(Gradient Boosting)是性能最优的技术,其中梯度提升位居前列。近年来,针对表格数据设计的深度学习方法不断涌现,但在小规模数据集上仍难以超越梯度提升算法的性能。为此,我们提出了一种名为“Hopular”的新型深度学习架构,专为中等及小规模数据集设计。该架构的每一层均配备了连续型现代霍普菲尔德网络(continuous modern Hopfield networks)。这些现代霍普菲尔德网络能够利用存储的数据,识别特征-特征、特征-目标以及样本-样本之间的依赖关系。Hopular的创新之处在于:每一层均可通过霍普菲尔德网络中存储的信息,直接访问原始输入以及整个训练集。因此,Hopular能够在每一层中像标准的迭代学习算法一样,逐步更新当前模型及其预测结果。在包含少于1,000个样本的小规模表格数据集上的实验表明,Hopular在性能上超越了梯度提升、随机森林、支持向量机,以及多种现有的深度学习方法。在中等规模表格数据集(约10,000个样本)上的实验进一步显示,Hopular在性能上优于XGBoost、CatBoost、LightGBM,以及一种针对表格数据设计的先进深度学习方法。因此,Hopular为表格数据建模提供了一种极具竞争力的替代方案,展现出在中、小规模数据场景下的显著优势。

代码仓库

ml-jku/hopular
官方
pytorch
GitHub 中提及

基准测试

基准方法指标
general-classification-on-shrutime CatBoost
Accuracy: 86.39 ± 0.04
general-classification-on-shrutimeNPTs
Accuracy: 85.62 ± 0.07
general-classification-on-shrutime LightGBM
Accuracy: 86.18 ± 0.02
general-classification-on-shrutimeHopular
Accuracy: 86.12 ± 0.09
general-classification-on-shrutimeXGBoost
Accuracy: 84.58 ± 0.00

用 AI 构建 AI

从想法到上线——通过免费 AI 协同编程、开箱即用的环境和市场最优价格的 GPU 加速您的 AI 开发

AI 协同编程
即用型 GPU
最优价格
立即开始

Hyper Newsletters

订阅我们的最新资讯
我们会在北京时间 每周一的上午九点 向您的邮箱投递本周内的最新更新
邮件发送服务由 MailChimp 提供
Hopular:面向表格数据的现代霍普菲尔德网络 | 论文 | HyperAI超神经