Bayesian Learning from Sequential Data using Gaussian Processes with Signature Covariances

Abstract

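We develop a Bayesian approach to learning from sequential data using Gaussian processes (GPs) with so-called signature kernels as covariance functions. This makes sequences of different lengths comparable and lets us rely on strong theoretical results from stochastic analysis. Signatures capture sequential structure with tensors that can scale unfavourably in sequence length and state-space dimension. To address this, we introduce a sparse variational approach with inducing tensors. We then combine the resulting GPs with long short-term memory networks (LSTMs) and gated recurrent units (GRUs) to build larger models that leverage the strengths of each approach, and we benchmark these GPs on multivariate time series (TS) classification datasets. Code is available at https://github.com/tgcsaba/GPSig.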
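The abstract's core ingredient, a signature kernel used as a GP covariance, can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the GPSig API. It treats each sequence as the piecewise-linear path through its points, computes the signature truncated at level `depth` with Chen's identity, and takes the kernel to be the plain sum of level-wise inner products. All function names (`path_signature`, `signature_kernel`, `gram_matrix`, `num_signature_features`) are hypothetical. The GPSig implementation may add further ingredients, such as lifting the state space through a static kernel, scaling the levels, and the sparse variational inducing tensors mentioned in the abstract. None of these appear here.

```python
import numpy as np

# Hypothetical helpers, NOT the GPSig API. Assumes sequences are piecewise-linear
# paths of raw increments, with no static-kernel lift and no level normalisation.


def segment_signature(delta, depth):
    """Truncated signature of one linear segment with increment `delta`.

    Level m of a straight segment is delta^{⊗m} / m!, built recursively
    as level_{m-1} ⊗ delta / m.
    """
    levels = [np.ones(())]  # level 0 is the scalar 1
    for m in range(1, depth + 1):
        levels.append(np.multiply.outer(levels[-1], delta) / m)
    return levels


def chen_product(a, b):
    """Product in the truncated tensor algebra: (a ⊗ b)_m = sum_k a_k ⊗ b_{m-k}.

    By Chen's identity, the signature of a concatenated path is the product
    of the signatures of its pieces.
    """
    depth = len(a) - 1
    return [sum(np.multiply.outer(a[k], b[m - k]) for k in range(m + 1))
            for m in range(depth + 1)]


def path_signature(x, depth):
    """Truncated signature of a sequence x of shape (length, d),
    viewed as the piecewise-linear path through its points."""
    x = np.asarray(x, dtype=float)
    d = x.shape[1]
    # Start from the identity element (1, 0, 0, ...) of the tensor algebra.
    sig = [np.ones(())] + [np.zeros((d,) * m) for m in range(1, depth + 1)]
    for delta in np.diff(x, axis=0):  # one Chen product per increment
        sig = chen_product(sig, segment_signature(delta, depth))
    return sig


def signature_kernel(x, y, depth):
    """Plain truncated signature kernel: sum over levels of <S(x)_m, S(y)_m>."""
    sx, sy = path_signature(x, depth), path_signature(y, depth)
    return float(sum(np.sum(a * b) for a, b in zip(sx, sy)))


def gram_matrix(seqs, depth):
    """Covariance matrix over sequences that may have different lengths."""
    sigs = [path_signature(s, depth) for s in seqs]
    n = len(sigs)
    K = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            K[i, j] = sum(np.sum(a * b) for a, b in zip(sigs[i], sigs[j]))
    return K


def num_signature_features(d, depth):
    """Entries in a signature truncated at `depth`: 1 + d + d^2 + ... + d^depth."""
    return sum(d ** m for m in range(depth + 1))


# Usage: a 3-point and a 2-point sequence tracing the same straight line
# have identical signatures, so the kernel compares them directly.
line3 = [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]
line2 = [[0.0, 0.0], [2.0, 0.0]]
print(signature_kernel(line3, line2, depth=2))  # → 9.0
print(num_signature_features(10, 4))            # → 11111
```

The usage lines compare sequences of lengths 3 and 2 directly. `num_signature_features` also makes the scaling problem visible. The level-m tensor has d^m entries, so d = 10 at depth 4 already needs 11111 numbers per sequence. The signature also costs one Chen product per increment, so the cost grows with sequence length. This is the cost that the sparse variational approach with inducing tensors is meant to address.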

Code repository

tgcsaba/GPSig
Official
tf
Mentioned in GitHub

Benchmarks

Benchmark, method, and metrics (NLL = negative log-likelihood)

time-series-classification-on
  GP-Sig-LSTM: Accuracy 0.991, NLL 0.031
  GP-LSTM: Accuracy 0.233, NLL 2.506
  GP-KConv1D: Accuracy 0.941, NLL 0.409
  GP-Sig: Accuracy 0.979, NLL 0.108
  GP-Sig-GRU: Accuracy 0.925, NLL 0.258
  GP-GRU: Accuracy 0.114, NLL 3.523

time-series-classification-on-arabicdigits
  GP-KConv1D: Accuracy 0.984, NLL 0.050
  GP-LSTM: Accuracy 0.985, NLL 0.082
  GP-Sig-GRU: Accuracy 0.994, NLL 0.023
  GP-GRU: Accuracy 0.986, NLL 0.066
  GP-Sig: Accuracy 0.979, NLL 0.071
  GP-Sig-LSTM: Accuracy 0.992, NLL 0.047

time-series-classification-on-auslan
  GP-Sig-GRU: Accuracy 0.978, NLL 0.123
  GP-KConv1D: Accuracy 0.784, NLL 1.900
  GP-Sig: Accuracy 0.925, NLL 0.550
  GP-LSTM: Accuracy 0.880, NLL 0.650
  GP-Sig-LSTM: Accuracy 0.983, NLL 0.106
  GP-GRU: Accuracy 0.949, NLL 0.248

time-series-classification-on-cmusubject16
  GP-Sig-LSTM: Accuracy 1.000, NLL 0.088
  GP-GRU: Accuracy 0.993, NLL 0.089
  GP-Sig: Accuracy 0.979, NLL 0.089
  GP-Sig-GRU: Accuracy 1.000, NLL 0.040
  GP-LSTM: Accuracy 0.924, NLL 0.270
  GP-KConv1D: Accuracy 0.897, NLL 0.255

time-series-classification-on-digitshapes
  GP-Sig-LSTM: Accuracy 1.000, NLL 0.008
  GP-Sig-GRU: Accuracy 1.000, NLL 0.035
  GP-KConv1D: Accuracy 1.000, NLL 0.035
  GP-Sig: Accuracy 1.000, NLL 0.021
  GP-LSTM: Accuracy 1.000, NLL 0.013
  GP-GRU: Accuracy 0.812, NLL 0.727

time-series-classification-on-ecg
  GP-KConv1D: Accuracy 0.760, NLL 0.543
  GP-LSTM: Accuracy 0.782, NLL 0.496
  GP-GRU: Accuracy 0.734, NLL 0.601
  GP-Sig-LSTM: Accuracy 0.816, NLL 0.402
  GP-Sig-GRU: Accuracy 0.832, NLL 0.431
  GP-Sig: Accuracy 0.848, NLL 0.356

time-series-classification-on-japanesevowels
  GP-Sig-LSTM: Accuracy 0.981, NLL 0.080
  GP-Sig: Accuracy 0.982, NLL 0.069
  GP-KConv1D: Accuracy 0.986, NLL 0.067
  GP-LSTM: Accuracy 0.982, NLL 0.061
  GP-Sig-GRU: Accuracy 0.985, NLL 0.053
  GP-GRU: Accuracy 0.986, NLL 0.052

time-series-classification-on-kickvspunch
  GP-Sig-GRU: Accuracy 0.820, NLL 0.493
  GP-LSTM: Accuracy 0.620, NLL 0.696
  GP-KConv1D: Accuracy 0.700, NLL 0.662
  GP-GRU: Accuracy 0.600, NLL 0.674
  GP-Sig: Accuracy 0.900, NLL 0.224
  GP-Sig-LSTM: Accuracy 0.900, NLL 0.301

time-series-classification-on-libras
  GP-GRU: Accuracy 0.742, NLL 1.110
  GP-Sig-LSTM: Accuracy 0.921, NLL 0.320
  GP-Sig: Accuracy 0.923, NLL 0.259
  GP-KConv1D: Accuracy 0.698, NLL 1.608
  GP-LSTM: Accuracy 0.776, NLL 0.911
  GP-Sig-GRU: Accuracy 0.899, NLL 0.346

time-series-classification-on-netflow
  GP-LSTM: Accuracy 0.928, NLL 0.251
  GP-Sig: Accuracy 0.937, NLL 0.189
  GP-KConv1D: Accuracy 0.945, NLL 0.168
  GP-GRU: Accuracy 0.926, NLL 0.194
  GP-Sig-GRU: Accuracy 0.921, NLL 0.259
  GP-Sig-LSTM: Accuracy 0.931, NLL 0.218

time-series-classification-on-pems
  GP-Sig: Accuracy 0.820, NLL 0.520
  GP-GRU: Accuracy 0.769, NLL 0.784
  GP-Sig-GRU: Accuracy 0.775, NLL 1.100
  GP-Sig-LSTM: Accuracy 0.763, NLL 0.704
  GP-LSTM: Accuracy 0.745, NLL 1.194
  GP-KConv1D: Accuracy 0.794, NLL 0.537

time-series-classification-on-pendigits
  GP-Sig-LSTM: Accuracy 0.928, NLL 0.289
  GP-Sig-GRU: Accuracy 0.902, NLL 0.399
  GP-KConv1D: Accuracy 0.946, NLL 0.181
  GP-GRU: Accuracy 0.951, NLL 0.187
  GP-Sig: Accuracy 0.955, NLL 0.146
  GP-LSTM: Accuracy 0.953, NLL 0.185

time-series-classification-on-shapes
  GP-KConv1D: Accuracy 1.000, NLL 0.012
  GP-Sig: Accuracy 1.000, NLL 0.011
  GP-GRU: Accuracy 0.867, NLL 0.168
  GP-Sig-GRU: Accuracy 1.000, NLL 0.012
  GP-LSTM: Accuracy 1.000, NLL 0.016
  GP-Sig-LSTM: Accuracy 1.000, NLL 0.014

time-series-classification-on-uwave
  GP-GRU: Accuracy 0.763, NLL 1.168
  GP-LSTM: Accuracy 0.870, NLL 0.745
  GP-KConv1D: Accuracy 0.947, NLL 0.189
  GP-Sig: Accuracy 0.964, NLL 0.140
  GP-Sig-GRU: Accuracy 0.968, NLL 0.121
  GP-Sig-LSTM: Accuracy 0.970, NLL 0.113

time-series-classification-on-wafer
  GP-LSTM: Accuracy 0.966, NLL 0.105
  GP-GRU: Accuracy 0.994, NLL 0.029
  GP-Sig-GRU: Accuracy 0.978, NLL 0.081
  GP-Sig: Accuracy 0.965, NLL 0.105
  GP-KConv1D: Accuracy 0.984, NLL 0.085
  GP-Sig-LSTM: Accuracy 0.988, NLL 0.048

time-series-classification-on-walkvsrun
  GP-Sig-GRU: Accuracy 1.000, NLL 0.030
  GP-GRU: Accuracy 1.000, NLL 0.028
  GP-LSTM: Accuracy 1.000, NLL 0.048
  GP-KConv1D: Accuracy 1.000, NLL 0.066
  GP-Sig-LSTM: Accuracy 1.000, NLL 0.030
  GP-Sig: Accuracy 1.000, NLL 0.023
