这份攻略帮你“稳住”反复无常的 GAN

这份攻略帮你“稳住”反复无常的 GAN

首页休闲益智方块推土机更新时间:2024-05-07
前言——这两天主要分享一些技术文章

GAN 自 2014 年提出以来得到了广泛应用,BigGAN 等生成的以假乱真的图像更是引发了众多关注,但由于训练稳定性较差,GAN 的使用变得非常困难。本文列出了一些提高 GAN 训练稳定性的常用技术。

生成对抗网络(GAN)是一类非常强大的神经网络,具有非常广阔的应用前景。GAN 本质上是由两个相互竞争的神经网络(生成器和判别器)组成的系统。

GAN 的工作流程示意图。

给定一组目标样本,生成器会试图生成一些人造的样本,这些生成的样本能够欺骗判别器将其视为真实的目标样本,达到「以假乱真」的目的。而判别器则会试图将真实的(目标)样本与虚假的(生成)样本区分开来。通过这样循环往复的训练方法,我们最终可以得到一个能够很好地生成与目标样本相似的样本的生成器。

由于 GAN 几乎可以学会模拟出所有类型的数据分布,它有着非常广泛的应用场景。通常,GAN 被用来去除图片中的人为影响、超分辨率、姿势迁移以及任何类型的图像转换,如下所示:

使用 GAN 完成的图像变换。

然而,由于 GAN 的训练稳定性反复无常,使用 GAN 是十分困难的。诚然,许多研究人员已经提出了很好的解决方案来缓解 GAN 训练中涉及的一些问题。然而,这一领域的研究进展是如此之快,以至于人们很难跟上这些最新的有趣的想法。本文列出了一些常用的使 GAN 训练稳定的技术。

使用 GAN 的弊端

由于一系列原因,想要使用 GAN 是十分困难的。本节列举出了其中的一些原因:

1. 模式崩溃

自然的数据分布是极其复杂的多模态函数(也称多峰函数)。也就是说,数据分布有许多「峰」或「模式」。每个模态代表相似的数据样本聚集在一起,但是与其它的模态并不相同。

在模式崩溃的情况下,生成器会生成从属于有限模态集集合的样本。当生成器认为它可以通过生成单一模式的样本来欺骗鉴别器时,就会发生这种情况。也就是说,生成器只从这种模式生成样本。

上面一排图片表示没有发生模式崩溃的情况下 GAN 输出的样本。下面一排图片表示发生模式崩溃时 GAN 输出的样本。

判别器最终会发现这种模式是人为生成的。结果,生成器会直接转而生成另一种模式。这样的情况会无限循环下去,从本质上限制了生成样本的多样性。详细解释请参考博客《Mode collapse in GANs》(http://aiden.nibali.org/blog/2017-01-18-mode-collapse-gans/)

2. 收敛性

在 GAN 的训练过程中,一个普遍的问题就是「何时停止训练 GAN 模型?」由于在判别器损失降低的同时生成器的损失会增高(反之亦然),我们并不能基于损失函数的值就来判别 GAN 的收敛性。下图说明了这一点:

一张典型的 GAN 损失函数示意图。请注意,此图无法说明 GAN 的收敛性。

3. 质量

和前面提到的问题一样,很难定量地判断生成器何时能生成高质量的样本。向损失函数中加入额外的感知正则化项可以在一定程度上帮助我们缓解这种情况。

4. 评价标准

GAN 的目标函数说明了生成器(G)与判别器(D)这一对相互博弈的模型相对于其对手的性能,但却不能代表输出样本的质量或多样性。因此,我们需要能够在目标函数相同的情况下进行度量的独特的评价标准。

术语

在我们深入研究可能有助于提升 GAN 模型性能的技术之前,让我们回顾一些术语。

1. 下确界及上确界

简而言之,下确界是集合的最大下界,上确界是集合的最小上界,「上确界、下确界」与「最小值、最大值」的区别在于下确界和上确界不一定属于集合。

2. 散度度量

散度度量代表了两个分布之间的距离。传统的 GAN 本质上是最小化了真实数据分布和生成的数据分布之间的 Jensen Shannon 散度(JS 散度)。GAN 的损失函数可以被改写为最小化其它的散度度量,例如:Kulback Leibler 散度(KL 散度)或全变分距离。通常,Wasserstein GAN 最小化了推土机距离。

3. Kantorovich Rubenstein 对偶性

我们很难使用一些散度度量的原始形式进行优化。然而,它们的对偶形式(用上确界替换下确界,反之亦然)可能就较为容易优化。对偶原理为将一种形式转化为另一种形式提供了框架。详细解释请参考博客:《Wasserstein GAN and the Kantorovich-Rubinstein Duality》(https://vincentherrmann.github.io/blog/wasserstein/)

4. LiPSCHITZ 连续性

一个 Lipschitz 连续函数的变化速度是有限的。对具备 Lipschitz 连续性的函数来说,函数曲线上任一点的斜率的绝对值不能超过实数 K。这样的函数也被称为 K-Lipschitz 连续函数。

Lipschitz 连续性是 GAN 所期望满足的,因为它们会限制判别器的梯度,从而从根本上避免了梯度爆炸问题。另外,Kantorovich-Rubinstein 对偶性要求 Wasserstein GAN 也满足 Lipschitz 连续性,详细解释请参考博客:《Spectral Normalization Explained》(https://christiancosgrove.com/blog/2018/01/04/spectral-normalization-explained.html)。

用于提升模型性能的技术

有许多技巧和技术可以被用来使 GAN 更加稳定和强大。为了保证本文的简洁性,我仅仅解释了一些相对来说较新或较为复杂的技术。在本节的末尾,我列举出了其它的各种各样的技巧和技术。

1. 替换损失函数

针对 GAN 存在的的缺点,最流行的修正方法之一是使用「Wasserstein GAN」。它本质上是使用「推土机距离」(Wasserstein-1 距离或 EM 距离)代替传统 GAN 的 Jensen Shannon 散度。然而,EM 距离的原始形式是难以进行优化的,因此我们使用它的对偶形式(通过 Kantorovich Rubenstein 对偶性计算得出)。这要求判别器满足「1-Lipschitz」,我们是通过裁剪判别器的权重来保证这一点的。

使用推土机距离的优点是,即使真实的和生成的样本的数据分布没有交集,推土机距离也是「连续的」,这与 JS 或 KL 散度不同。此外,此时生成图像的质量与损失函数值之间存在相关性。而使用推土机距离的缺点是,我们需要在每次更新生成器时更新好几个判别器(对于原始实现的每次生成器更新也是如此)。此外,作者声称,权值裁剪是一种糟糕的确保 1-Lipschitz 约束的方法。

与 Jensen Shannon 散度(如右图所示)不同,即使数据分布不是连续的,推土机距离(如左图所示)也是连续的。详细的解释请参阅论文《Wasserstein GAN》(https://arxiv.org/pdf/1701.07875.pdf)

另一种有趣的解决方案是采用均方损失而非对数损失。LSGAN 的作者认为,传统的 GAN 损失函数并没有提供足够的刺激来「拉动」生成的数据分布逼近真实的数据分布。

原始 GAN 损失函数中的对数损失并不影响生成数据与决策边界之间的距离(决策边界将真实数据和生成的数据分开)。另一方面,LSGAN 对远离决策边界的生成样本进行惩罚,本质上将生成的数据分布「拉向」实际的数据分布。它通过使用均方损失替代对数损失来做到这一点。详细解释请参考博客:《Least Squares GAN》(https://wiseodd.github.io/techblog/2017/03/02/least-squares-gan/)。

2. 两个时间尺度上的更新规则(TTUR)

在此方法中,我们为判别器和生成器使用了不同的学习率。通常,生成器使用较慢的更新规则,而判别器使用较快的更新规则。通过使用这种方法,我们只需对学习率进行微调,就可以以 1:1 的比例执行生成器和判别器的更新。值得注意的是,SAGAN 的实现就使用了这个方法。

3. 梯度惩罚

在论文「Improved Training of WGANs」中,作者声称权值裁剪(正如在原始的 WGAN 中执行的那样)导致一些优化问题的产生。作者认为权重裁剪迫使神经网络去学习「较为简单的近似」从而得到最优的数据分布,这导致 GAN 得到的最终结果质量变低。他们还声称,如果 WGAN 的超参数设置不正确,权重裁剪会导致梯度爆炸或梯度消失的问题。作者在损失函数中引入了一个简单的梯度惩罚规则,从而缓解了上述问题。除此之外,正如在原始的 WGAN 实现中那样,这样做还保证了 1-Lipschitz 连续性。

正如在原始的 WGAN-GP 论文中提到的,将梯度惩罚作为正则化项加入。

DRAGAN 的作者声称,当 GAN 中进行的博弈(即判别器和生成器互相进行对抗)达到了「局部均衡状态」时,模式崩溃现象就会发生。他们还声称,此时由判别器所贡献的梯度是非常「尖锐的」。使用这样的梯度惩罚能够很自然地帮助我们避开这些状态,大大提高训练的稳定性,并减少模式崩溃现象的发生。

4. 谱归一化

谱归一化是一种通常在判别器中使用的权值归一化技术,它能够优化训练过程(使训练过程更稳定),从本质上保证了判别器满足「K-Lipschitz 连续性」。

SAGAN 等实现也在生成器中使用了谱归一化技术。博文《Spectral Normalization Explained》(https://christiancosgrove.com/blog/2018/01/04/spectral-normalization-explained.html)也指出,谱归一化比梯度惩罚的计算效率更高。

5. 展开和打包

正如博文《Mode collapse in GANs》所描述的,一个阻止模式崩溃发生的方法是在更新参数时预测「对策」。当判别器有机会对生成器的结果做出反应时(考虑到对策,就像 min-max 方法),展开(unrolled)的 GAN 就可以让生成器骗过判别器。

另一个阻止模式崩溃发生的方式是将属于同一类的一些样本「打包」,然后将其传给判别器。这种方法被 PacGAN 所采用,该论文声称它们减少了模式崩溃的发生。

6. 堆叠 GAN

单个的 GAN 可能不够强大,无法有效地处理某些任务。因此,我们可以使用连续放置的多个 GAN,其中每个 GAN 可以解决一个简化的问题模块。例如,FashionGAN 使用了两个 GAN 处理局部的图像转换任务。

FashionGAN 使用了两个 GAN 来执行局部的图像转换。

将这种情况推到极致,可以逐步提高 GAN 模型所面临问题的难度。例如,Progressive GAN(ProGAN)可以生成分辨率超高的高质量图像。

7. 相对 GAN

传统的 GAN 会度量生成数据是真实数据的概率。相对 GAN(Relativistic GAN)则会去度量生成数据比真实数据「更加真实」的概率。正如 RGAN 相关论文《The relativistic discriminator: a key element missing from standard GAN》中提到的那样,我们可以使用一个合适的距离来度量这种「相对真实性」。

图 B 为我们使用标准 GAN 损失得到的判别器的输出。图 C 为输出的曲线实际的样子。图 A 为 JS 散度的最优解。

作者还提到,当判别器达到最优状态时,其输出的概率 D(x)应该收敛到 0.5。然而,传统的 GAN 训练算法会迫使判别器为任何图像输出「真实」(即概率为 1)的结果。这在某种程度上阻止了判别器的输出概率达到其最优值。相对 GAN 也解决了这个问题,并且如下图所示,取得了非常显著的效果。

在 5000 轮迭代后,标准 GAN 得到的输出(左图),以及相对 GAN 得到的输出(右图)。

8. 自注意力机制

自注意力 GAN 的作者声称,用于生成图像的卷积操作关注的是局部传播的信息。也就是说,由于它们的感受野(restrictive receptive field)有限,它们忽略了在全局传播的关系。

将注意力映射(由黄色方框中的网络计算得出)加入到标准的卷积运算中。

自注意力生成对抗网络使图像生成任务能够进行注意力机制驱动的远距离依赖建模。自注意力机制是对于常规的卷积运算的补充。全局信息(远距离依赖)有助于生成更高质量的图像。网络可以选择忽略注意力机制,或将其与常规的卷积运算一同进行考虑。要想更细致地了解自注意力机制,请参阅论文《Self-Attention Generative Adversarial Networks》(https://arxiv.org/pdf/1805.08318.pdf)。

9. 其它各种各样的技术

下面是其它的一些被用来提升 GAN 模型性能的技术(不完全统计!):

你可以通过论文《Improved Techniques for Training GANs》以及博文《From GAN to WGAN》了解更多关于这些技术的信息。在下面的 GitHub 代码仓库中列举出了更多的技术:https://github.com/soumith/ganhacks。

评价指标

到目前为止,读者已经了解了提升 GAN 训练效果的方法,我们需要使用一些指标来量化证明这些方法有效。下面,本文将列举出一些常用的 GAN 模型的性能评价指标。

1. Inception(GoogleNet)得分

Inception 得分可以度量生成数据有多「真实」。

Inception Score 的计算方法。

上面的方程由两个部分(p(y|x) 和 p(y))组成。在这里,x 代表由生成器生成的图像,p(y|x) 是将图像 x 输入给一个预训练好的 Inception 网络(正如在原始实现中使用 ImageNet 数据集进行预训练,https://arxiv.org/pdf/1801.01973.pdf)时得到的概率分布。同时,p(y) 是边缘概率分布,可以通过对生成图像 x 的一些不同的样本求 p(y|x) 平均值计算得出。这两项代表了真实图像所需要满足的两种特性:

  1. 生成图像应该包含「有意义」的目标(清晰、不模糊的目标)。这就意味着 p(y|x) 应该具有「较小的熵」。也就是说,我们的 Inception 网络必须非常有把握地确定生成的图像从属于某个特定的类。
  2. 生成的图像应该要「多样」。这就意味着 p(y) 应该有「较大的熵」。换句话说,生成器应该在生成图像时使得每张图像代表不同类的标签(理想情况下)。

理想状况下 p(y|x) 和 p(y) 的示意图。这种情况下,二者的 KL 散度非常大。

如果一个随机变量是高度可预测的,那么它的熵就很小(即,p(y) 应该是有一个尖峰的分布)。相反,如果随机变量是不可预测的,其熵就应该很大(即 p(y|x) 应该是一个均匀分布)。如果这两个特性都得到了满足,我们应该认为 p(y|x) 和 p(y) 的 KL 散度很大。自然,Inception 得分(IS)越大越好。如果读者想要了解对 Inception 得分更加深入的分析,请参阅论文《A Note on the Inception Score》(https://arxiv.org/pdf/1801.01973.pdf)。

2. Fréchet Inception 距离(FID)

Inception 得分的一个不足之处在于,并没有对真实数据和生成数据的统计量(如均值和方差)进行比较。Fréchet 距离通过对比真实图像和生成图像的均值和方差解决了这个问题。Fréchet Inception 距离(FID)执行了与 Inception 得分相同的分析过程,但是它是在通过向预训练好的 Inception-v3 网络传入真实的和生成的图像后得到的特征图上完成的。FID 的公式如下所示:

FID 得分对比了真实的数据分布和生成数据分布的均值和方差。「Tr」代表矩阵的「迹」。

FID 得分越低越好,因为此时它表明生成图像的统计量与真实图像非常接近。

结语

为了克服 GAN 训练中的种种弊端,研究社区提出了许多解决方案和方法。然而,由于大量涌现的新研究成果,很难跟进所有有意义的新工作。因此,本文分享的细节是不详尽的,并且可能在不久的将来就会过时。但是,笔者希望本文可以为那些想要提高 GAN 模型性能的人提供一定的指导。

查看全文
大家还看了
也许喜欢
更多游戏

Copyright © 2024 妖气游戏网 www.17u1u.com All Rights Reserved