3 个月前

软截断:一种面向高精度得分估计的基于得分的扩散模型通用训练技术

软截断:一种面向高精度得分估计的基于得分的扩散模型通用训练技术

摘要

近年来,扩散模型在图像生成任务中取得了最先进的性能。然而,先前关于扩散模型的实证研究暗示,密度估计能力与样本生成性能之间存在一种负相关关系。本文通过充分的实证证据表明,这种负相关现象的根源在于:密度估计主要依赖于较小的扩散时间,而样本生成则主要依赖于较大的扩散时间。然而,在整个扩散时间范围内训练一个表现良好的得分网络(score network)极具挑战性,因为不同扩散时间步上的损失尺度存在显著不平衡。为实现有效训练,本文提出一种通用性强的训练技术——软截断(Soft Truncation),该方法将原本固定且静态的截断超参数转化为一个随机变量,从而缓解损失尺度不平衡的问题。在实验中,软截断方法在CIFAR-10、CelebA、CelebA-HQ 256×256以及STL-10等多个数据集上均取得了当前最先进的性能。

代码仓库

Kim-Dongjun/Soft-Truncation
官方
pytorch
GitHub 中提及

基准测试

基准方法指标
image-generation-on-celeba-64x64DDPM++ (VP, NLL) + ST
FID: 2.9
bits/dimension: 1.96
image-generation-on-celeba-64x64UNCSN++ (RVE) + ST
bits/dimension: 1.97
image-generation-on-celeba-64x64DDPM++ (VP, FID) + ST
FID: 1.9
bits/dimension: 2.1
image-generation-on-celeba-hq-256x256UNCSN++ (RVE) + ST
FID: 7.16
image-generation-on-ffhq-256-x-256UDM (RVE) + ST
FID: 5.54
image-generation-on-imagenet-32x32DDPM++ (VP, NLL) + ST
FID: 8.42
Inception score: 11.82
bpd: 3.85
image-generation-on-lsun-bedroom-256-x-256UDM (RVE) + ST
FID: 4.57
image-generation-on-stl-10UNCSN++ (RVE) + ST
FID: 7.71
Inception score: 13.43

用 AI 构建 AI

从想法到上线——通过免费 AI 协同编程、开箱即用的环境和市场最优价格的 GPU 加速您的 AI 开发

AI 协同编程
即用型 GPU
最优价格
立即开始

Hyper Newsletters

订阅我们的最新资讯
我们会在北京时间 每周一的上午九点 向您的邮箱投递本周内的最新更新
邮件发送服务由 MailChimp 提供
软截断:一种面向高精度得分估计的基于得分的扩散模型通用训练技术 | 论文 | HyperAI超神经