3 个月前

一致性轨迹模型:学习扩散过程的概率流ODE轨迹

一致性轨迹模型:学习扩散过程的概率流ODE轨迹

摘要

一致性模型(Consistency Models, CM)(Song 等,2023)通过牺牲生成样本质量为代价,实现了基于得分的扩散模型采样加速,但缺乏一种自然的机制来在质量与速度之间进行权衡。为解决这一局限性,我们提出一致性轨迹模型(Consistency Trajectory Model, CTM),该模型是一种广义框架,可将 CM 与基于得分的模型作为特例统一纳入其中。CTM 训练单一神经网络,仅需一次前向传播即可输出得分(即对数密度的梯度),并支持在扩散过程的概率流常微分方程(Probability Flow Ordinary Differential Equation, ODE)轨迹上,任意初始时间与终末时间之间的无限制遍历。CTM 能够高效结合对抗训练与去噪得分匹配损失,显著提升生成性能,在 CIFAR-10 数据集上实现了单步扩散模型采样新最优的 FID 指标(FID = 1.73),在 64×64 分辨率下的 ImageNet 数据集上也达到了 FID = 1.92 的新纪录。此外,CTM 还引入了一类全新的采样方案,涵盖确定性与随机性方法,支持沿 ODE 解轨迹进行长距离跳跃式采样。随着计算资源的增加,CTM 始终能持续提升生成样本质量,避免了传统 CM 中常见的性能退化现象。更重要的是,与 CM 不同,CTM 可直接访问得分函数,从而便于集成扩散模型领域中成熟的可控生成与条件生成方法。同时,该得分函数的可访问性也支持对生成样本的似然(likelihood)计算。相关代码已开源,地址为:https://github.com/sony/ctm。

代码仓库

sony/ctm
官方
pytorch
GitHub 中提及

基准测试

基准方法指标
image-generation-on-imagenet-64x64CTM (NFE 1)
Inception Score: 70.38
NFE: 1
image-generation-on-imagenet-64x64CTM
FID: 1.73
Inception Score: 64.29
NFE: 2

用 AI 构建 AI

从想法到上线——通过免费 AI 协同编程、开箱即用的环境和市场最优价格的 GPU 加速您的 AI 开发

AI 协同编程
即用型 GPU
最优价格
立即开始

Hyper Newsletters

订阅我们的最新资讯
我们会在北京时间 每周一的上午九点 向您的邮箱投递本周内的最新更新
邮件发送服务由 MailChimp 提供
一致性轨迹模型:学习扩散过程的概率流ODE轨迹 | 论文 | HyperAI超神经