「周末AI课堂」理解GAN|机器学习你会遇到的“坑”

2019-01-27 21:36发布





鹰与蛇


想象在很久很久以前,鹰就喜欢抓蛇吃,但是蛇和鹰还不是现在看到的样子。后来蛇为了避免鹰抓到它,进化出了更为光滑的表皮,鹰为了更好的抓到蛇,爪子也变得越来越锋利。再后来,蛇为了不让鹰看到它,进化出了与环境融为一体的保护色,鹰为了更好的发现蛇,进化出了更为犀利的眼神,甚至学会了“打草惊蛇”的技术,让蛇在移动的过程中,更容易分辨蛇和草。他们在竞争中不断完善自己......

生成对抗网络,又叫做GAN(Generative Adversarial Networks),是一种著名的生成式模型。我们在上一节《理解变分自编码器》讲了VAE,它的主要特点是将自编码器中的隐向量重参数化为标准的正态分布,在训练过程中要同时存在encoder和decoder,然后我们将decoder部分独立拆分出来作为生成器。但在VAE中,我们可能会保持输入变量与输出图片的关系,可能会清楚图片中的鼻子或者眼睛分别代表着输入变量的哪一部分,但这并不意味着生成的图片就“好”。

“好”的标准是模糊的,我们判断图片好不好的一个可能标准就是,我们能不能分辨出哪些图片是编造的,哪些是真实的,如果不能,说明这个生成式模型是好的。GAN的出发点就是想训练出一个可以生成真实图片的生成式模型,训练中同时存在着生成器和判别器,生成器就类似于VAE的decoder,接受变量,生成图片,判别器就是一个普通的分类器,用来判断生成器所生成的图片是否是真实的。所以,GAN的图片要比VAE的图片更像真实图片。

判别器和生成器的关系就像是鹰和蛇的关系,生成器的目的是产生以假乱真的图片,判别器的目的是将假的图片和真的图片分开。



如图,我们将向量输入到生成器(Generator)中,生成器的输出进入到判别器(Discriminator)中,判别这是否是一个真实图片。

真正的关键问题是,他们是如何相互竞争(对抗)的。实际上,我们想让第二代的生成器可以骗过第一代的判别器,就是将判别器的参数全部固定,然后反向传播去调整生成器的参数,判别器无法识别出真假,此时就得到了第二代生成器。同时,我们想让第二代的判别器可以识别第一代的生成器,那么就将生成器的参数全部固定,反向传播去调整判别器的参数,得到第二代判别器。如此反复,直到达到我们满意的结果,作为生成式模型,我们最终需要的是生成器。

从数学角度来说,真实数据的分布为

,生成器所产生的是模拟分布

,判别器所判断的是

的一致性:



图为GAN的示意图,从左到右,分别表示进化的过程,z表示输入的向量,x表示生成的数据,绿线表示生成数据所构成的分布,黑线表示真实的分布,蓝线表示两者的差异,就是判别器本身,我们可以看到,一开始,模拟分布和真实分布存在不少的差异,但随着不断进化,最终,模拟分布和真实分布一致,代表着我们的生成器生成了以假乱真的图片。


判别器作为损失函数


GAN虽然属于无监督学习,因为我们没用到它的类别信息。事实上,我们在判别器中却隐含了类别,即生成器生成的图片是假的,额外引入的用以比较的图片是真实的。按照常理来说,我们所见到所有优雅的模型,只要涉及类别信息,都会用一个Loss function来表达这种不一致性,即Loss越大,表示越不属于这一类别,Loss越小,代表和这一类别越相似,然后我们对Loss function执行梯度下降,找出最优的参数值。

但是涉及到真实图片和虚假图片,我们并不好使用单纯的一个函数来表达这样的关系,因为函数关系要涉及到像素值的计算,真实和虚假图片的像素值关系并不是很容易表达,尤其是,很多时候,真实图片之间的像素值差别大,真实图片和虚假图片之间的像素值差别小。

GAN采用了判别器作为我们的Loss function。固定好判别器,更新生成器的参数,这一过程就好比我们使用优化算法来迭代更新模型的参数,为的只是获得更小的Loss;而固定好生成器,更新判别器的参数,这一过程就像是我们针对性将模型的Loss更换为更精准的Loss,衡量真实图片与虚假图片更精确的差别。

所以从整个GAN的框架来看,采用了判别器作为Loss function的监督学习。那么我们自然会想,如果判别器代表着损失函数,那么我们一开始就将判别器训练的特别好,应该对生成器的提高有很大作用,但在初始的GAN的使用中恰恰相反,如果很早就将判别器训练地很好,生成器会变差。

这是为什么呢?

让我们回到开头那个例子,如果鹰一开始就进化出了更为犀利的眼神,甚至学会了“打草惊蛇”的技术,让蛇在移动的过程中,更容易分辨蛇和草,那么蛇可能就不会先进化出保护色,而是在鹰打草惊蛇的时候就不再逃跑,但是没有保护色的情况下,显然跑是更好的选择。过于强大的鹰可能会使得蛇的进化方向乱掉,正如我们的判别器,太过强大,就让生成器变得无所适从。

这是辅助理解的很好手段。从数学上来说,这是因为我们衡量两者差异的所使用的KL散度或者JS散度(Jensen-Shannon )随着生成分布和真实分布的样本空间不一致性的加大,这两个度量会失效。就是由于判别器太过强大,或者生成器太过弱小而导致度量无法反映出分布的差异。我们当然可以更改判别器,但是这一操作并不优雅也不现实,我们对他的改进方式就是采用了Wasserstein距离,即便样本空间不一致性非常巨大,Wasserstein距离仍然能够反分布之间的差异。(这也是WGAN的基本思想)



如图,我们换用不同的对真实分布和生成分布的度量办法后,我们在训练过程中会发现普通的GAN会产生梯度消失,这就是由于分布差异过大引起的,而在WGAN中,梯度近似线性化的。



作者结束语

AI仍然是一个火热的话题,去年机缘巧合之下在读芯术开始了这趟旅途,至此,周末AI课堂总共60讲已经全部结束了。期间,我选择了一些自认为值得一讲的主题,并尽可能的减少公式的推导,代码也尽量简单易读,希望能让有所基础的人更进一步理解公式和模型背后的直觉意义,让有所兴趣的人能够一窥机器学习的框架,快速搭建并验证模型。

如能得到诸位的喜欢,不胜感激。

山高水长,我们江湖再见。

文章来源: https://www.toutiao.com/group/6651129670995542535/