3 个月前

逐步提升 Graph WaveNet 在交通预测中的性能

逐步提升 Graph WaveNet 在交通预测中的性能

摘要

我们提出了一系列改进措施,显著提升了Graph WaveNet在METR-LA交通预测任务上的先前最优性能。该任务的目标是基于过去一小时的传感器读数,预测交通网络中每个传感器未来时刻的车速。Graph WaveNet(GWN)是一种时空图神经网络,通过交替使用图卷积聚合邻近传感器的信息,以及使用空洞卷积捕捉历史信息来实现预测。我们对GWN进行了三项改进:(1)采用更优的超参数设置;(2)引入新的连接结构,使更大的梯度能够更有效地反向传播至早期卷积层;(3)在更简单的短期交通预测任务上进行预训练。这些改进使METR-LA任务上的平均绝对误差降低了0.06,降幅几乎相当于GWN相较于其前代模型的性能提升幅度。该改进在PEMS-BAY数据集上也表现出良好的泛化能力,相对提升幅度相似。此外,我们还证明了分别训练短期与长期预测模型并进行集成,可进一步提升整体性能。相关代码已开源,地址为:https://github.com/sshleifer/Graph-WaveNet。

代码仓库

sshleifer/Graph-WaveNet
官方
pytorch
GitHub 中提及
david-amirault/amhs
pytorch
GitHub 中提及
josegg05/eRGWnet
pytorch
GitHub 中提及
nnzhan/Graph-WaveNet
pytorch
GitHub 中提及

基准测试

基准方法指标
traffic-prediction-on-metr-laFinetune from t1-6 checkpoint
MAE @ 12 step: 3.47

用 AI 构建 AI

从想法到上线——通过免费 AI 协同编程、开箱即用的环境和市场最优价格的 GPU 加速您的 AI 开发

AI 协同编程
即用型 GPU
最优价格
立即开始

Hyper Newsletters

订阅我们的最新资讯
我们会在北京时间 每周一的上午九点 向您的邮箱投递本周内的最新更新
邮件发送服务由 MailChimp 提供
逐步提升 Graph WaveNet 在交通预测中的性能 | 论文 | HyperAI超神经