
摘要
所提出的“判别器引导”(Discriminator Guidance)方法旨在提升预训练扩散模型的样本生成质量。该方法引入一个判别器,对去噪样本路径是否具有真实性提供显式监督。与生成对抗网络(GANs)不同,本方法无需联合训练得分网络与判别器网络。相反,我们先完成得分网络的训练,再单独训练判别器,从而确保判别器训练过程稳定且收敛迅速。在样本生成阶段,我们在预训练得分函数中添加一个辅助项,以欺骗判别器。该辅助项可将模型得分修正为在最优判别器下的数据得分,表明判别器以互补方式提升了得分估计的准确性。基于该算法,我们在 ImageNet 256×256 数据集上取得了当前最优的生成效果,FID 达到 1.83,召回率(recall)为 0.64,与验证集的 FID(1.68)和召回率(0.66)极为接近。相关代码已开源,地址为:https://github.com/alsdudrla10/DG。
代码仓库
alsdudrla10/DG
官方
pytorch
GitHub 中提及
alsdudrla10/DG_imagenet
官方
pytorch
GitHub 中提及
基准测试
| 基准 | 方法 | 指标 |
|---|---|---|
| conditional-image-generation-on-cifar-10 | EDM-G++ (conditional) | FID: 1.64 |
| image-generation-on-celeba-64x64 | STDDPM-G++ | FID: 1.34 |
| image-generation-on-cifar-10 | Discriminator Guidance (unconditional) | FID: 1.77 |
| image-generation-on-imagenet-256x256 | ADM-G++ (Recall) | FID: 4.45 |
| image-generation-on-imagenet-256x256 | Discriminator Guidance | FID: 1.83 |
| image-generation-on-imagenet-256x256 | ADM-G++ (FID) | FID: 3.18 |