
Abstract
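In many domains, data objects can be decomposed into a set of simpler objects. It is therefore natural to represent each object as the set of its components or parts. Many conventional machine learning algorithms cannot handle this representation, because sets may vary in cardinality and their elements lack a meaningful ordering. In this paper, we present a new neural network architecture, called RepSet, that can handle examples represented as sets of vectors. The proposed model computes correspondences between an input set and a collection of hidden sets by solving a series of network flow problems. This representation is then fed to a standard neural network architecture to produce the output. The architecture supports end-to-end gradient-based learning. We demonstrate RepSet on classification tasks, including text categorization and graph classification, and show that the proposed neural network outperforms or performs comparably to state-of-the-art algorithms.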
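To make the idea of "correspondences computed by network flow problems" concrete, here is a minimal forward-pass sketch in plain Python. It assumes (for illustration only, not as the authors' exact formulation) that each hidden set contributes one feature: the value of a maximum-weight bipartite matching between the input elements and the hidden elements, with edge weights given by the ReLU of their dot products. Maximum-weight bipartite matching is one instance of a network flow problem. The helper names `similarity_matrix`, `max_weight_matching` and `repset_features` are hypothetical, and the exhaustive search is only viable for tiny sets.

```python
from itertools import permutations


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def similarity_matrix(input_set, hidden_set):
    # Assumed edge weights: ReLU of inner products, so all weights are >= 0.
    return [[max(0.0, dot(x, h)) for h in hidden_set] for x in input_set]


def max_weight_matching(weights):
    """Value of a maximum-weight bipartite matching, by exhaustive search.

    Because weights are non-negative, some optimal matching pairs exactly
    min(n, m) elements. A practical implementation would use a flow or
    Hungarian solver instead of enumerating assignments.
    """
    n = len(weights)
    m = len(weights[0]) if n else 0
    if n == 0 or m == 0:
        return 0.0
    if n > m:  # transpose so that rows are the smaller side
        weights = [list(col) for col in zip(*weights)]
        n, m = m, n
    best = 0.0
    for cols in permutations(range(m), n):  # injective row -> column maps
        best = max(best, sum(weights[r][c] for r, c in enumerate(cols)))
    return best


def repset_features(input_set, hidden_sets):
    # One feature per hidden set; the result does not depend on the order
    # of the input elements, and the input set may have any cardinality.
    return [max_weight_matching(similarity_matrix(input_set, h))
            for h in hidden_sets]


X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]          # an input set of 3 vectors
H = [[[1.0, 0.0], [0.0, 2.0]], [[-1.0, 0.0]]]     # two (hypothetical) hidden sets
print(repset_features(X, H))                      # [3.0, 0.0]
```

The resulting fixed-length vector is what a standard network (for example, a linear layer followed by a softmax) would consume. This sketch covers only the forward pass; in the trained model the hidden sets are learnable parameters updated by gradient descent.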
Code repository: giannisnik/repset (official, PyTorch)
Benchmarks
| Benchmark | Method | Accuracy (%) |
|---|---|---|
| document-classification-on-amazon | ApproxRepSet | 94.31 |
| document-classification-on-bbcsport | ApproxRepSet | 95.73 |
| document-classification-on-classic | ApproxRepSet | 96.24 |
| document-classification-on-recipe | ApproxRepSet | 59.06 |
| document-classification-on-reuters-21578 | ApproxRepSet | 97.17 |
| document-classification-on-twitter | ApproxRepSet | 72.6 |
| graph-classification-on-imdb-b | ApproxRepSet | 71.46 |
| graph-classification-on-imdb-m | ApproxRepSet | 48.92 |
| graph-classification-on-mutag | ApproxRepSet | 86.33 |
| graph-classification-on-proteins | ApproxRepSet | 70.74 |
| graph-classification-on-reddit-b | ApproxRepSet | 80.3 |
| text-classification-on-20news | ApproxRepSet | 76.18 |
| text-classification-on-ohsumed | ApproxRepSet | 64.06 |