3.2 GAN的数学原理
3.2.1 最大似然估计
为了理解生成对抗网络的基本原理,我们首先要讨论一下最大似然估计,看它是如何运用在生成模型上的。在最大似然估计中,我们首先会对真实训练数据集定义一个概率分布函数Pdata(x),其中的x相当于真实数据集中的某个数据点。
同样地,为了逼近真实数据的概率分布,我们也会为生成模型定义一个概率分布函数Pmodel(x; θ),这个分布函数也是通过参数变量θ定义的。在实际的计算过程中,我们希望改变参数θ,从而使得生成模型概率分布Pmodel(x; θ)能够逼近真实数据概率分布Pdata(x)。
当然在实际运算中,我们是无法知道Pdata(x)的形式的,我们唯一可以做的是从真实数据集中采样大量的数据,也就是说从Pdata(x)中取出{x1, x2,…, xm},通过这些真实的样本数据,我们计算对应的生成模型概率分布Pmodel(x(i);θ)。上述的{x1, x2,…, xm}也就是所谓的训练集,例如当我们希望生成模型能够生成猫咪的图片,那要做的就是先从互联网上找出大量的真实猫咪图片作为训练集。
现在根据训练数据集可以写出概率函数,通过所有的真实样本计算出在生成模型中的概率并全部进行相乘。
现在最大似然估计的目标是通过上面这个概率的式子,寻找出一个θ*使得L最大化。这样做的实际含义是指,在给出真实训练集的前提下,我们希望生成模型能够在这些数据上具备最大的概率,这样才说明我们的生成模型在给出的训练集上能够逼近真实数据的概率分布。
相比于连乘,这里使用求和运算更简单一些,所以我们对所有的pmodel(x(i);θ)取一个对数,把相乘转化为相加。
对于上述公式,我们可以把求和近似转化为求log pmodel(x; θ)的期望值,然后我们可以推导出积分的形式。
我们可以通过图3-7去理解上面的推导过程,假设我们的训练数据是满足高斯分布的一维数据,最终训练后的生成模型概率分布应该能够满足尽可能多的训练样本点。
图3-7 生成模型概率分布
在推导出上述积分公式后,我们在不影响求解的情况下在式(3-9)的基础上减去一个与θ没有关系的常数项∫pdata(x)log pdata(x)dx,如下面的推导所示,我们需要找到一个θ*使得下面的推导结果最小。
之前我们在介绍VAE的时候提到了KL散度,它是一种计算概率分布之间相似程度的计算方法。现在我们来看一下KL散度的公式,我们设定两个概率分布分别为P和Q,在假定为连续随机变量的前提下,它们对应的概率密度函数分别为p(x)和q(x),我们可以写出如下公式:
从式(3-14)可以看出,当且仅当P = Q时,KL(P‖Q) = 0。此外我们也可以发现KL散度具备非负的特性,即KL(P‖Q) ≥ 0。但是从公式中我们也可以发现,KL散度不具备对称性,也就是说P对于Q的KL散度并不等于Q对于P的KL散度。
在特定情况下,通常是P用以表示数据的真实分布,而Q则表示数据的模型分布或是近似分布。那么让我们来对比一下之前推导的公式与KL散度,可以发现是完全一致的,那么我们可以继续将公式推导成KL散度的形式:
我们希望最小化真实数据分布与生成模型分布之间的KL散度,从而使得生成模型尽可能接近真实数据的分布。在实际实践中,我们是几乎不可能知道真实数据分布Pdata(x)的,而是需要使用训练数据形成的经验分布逼近真实数据分布Pdata(x)。
在实践中我们会发现使用最大似然估计方法的生成模型通常会比较模糊,原因是一般的简单模型无法使得pmodel(x; θ)真正逼近真实数据分布,因为真实数据是非常复杂的。为了模拟复杂分布,解决方法是采用神经网络(例如GAN)去实现pmodel(x; θ),可以把简单分布映射成几乎任何复杂的分布。
Ian在NIPS2016的文章中给出了基于似然估计的生成模型分类,如图3-8所示。
图3-8 基于似然估计的生成模型分类
图3-8中说明了基于似然估计的生成模型可以分为两个主要分支,一类是显式模型,另一类是隐式模型,两者的核心差别在于生成模型是否需要计算出一个明确的概率分布密度函数。在大部分情况下,研究生成模型的目的往往在于生成数据,我们对于分布密度函数是什么样的可能并没有太大的兴趣。本书的主角GAN属于后者,它解决了很多现有模型存在的问题,比如计算复杂度高、难以扩展到高维度等,当然它也引出了很多新的问题亟待研究者去解决。
3.2.2 GAN的数学推导
从之前几节我们可以了解到,生成模型会从一个输入空间将数据映射到生成空间,写成公式的形式是x = G(z)。通常我们的输入z会满足一个简单形式的随机分布,比如高斯分布或者均匀分布等,为了使得生成空间的数据分布能够尽可能地逼近真实数据分布,生成函数G会是一个神经网络的形式,通过神经网络我们可以模拟出各种完全不同的分布类型。
虽然我们可以清楚知道前置输入数据z的概率分布函数,但在经过一个神经网络的情况下我们难以计算最终的生成空间分布Pmodel(x) ,这样就无法计算3.2.1节中的概率函数L。
现在我们来看一下生成对抗网络是如何解决这个问题的。
首先看一下生成对抗网络中的代价函数,以判别器D为例,代价函数写作J(D),形式如下所示。后面我们会解释使用这种形式的原因。
对于生成器来说,它和判别器是紧密相关的,我们可以把两者看作一个零和博弈,它们的代价综合应该是零,所以生成器的代价函数应满足如下等式。
J(G) = −J(D) (3-17)
这样一来,我们可以设置一个价值函数V来表示J(G)和J(D)。
我们现在把问题变成了需要寻找一个合适的V(θ(D),θ(G)),使得J(G)和J(D)都尽可能小,也就是说对于判别器而言,V(θ(D),θ(G))越大越好,而对于生成器来说,则是V(θ(D),θ(G))越小越好,从而形成了两者之间的博弈关系。
在博弈论中,博弈双方的决策组合会形成一个纳什平衡点(Nash equilibrium),在这个博弈平衡点下博弈中的任何一方将无法通过自身的行为而增加自己的收益。这里有一个经典的囚徒困境例子来进一步说明纳什平衡点。两名囚犯被警方分开单独审讯,他们被告知的信息如下:如果一方招供而另一方不招供,则招供的一方将可以立即释放,而另一方会被判处10年监禁;如果双方都招供的话,每个人都被判处两年监禁;如果双方都不招供,则每个人都仅被判半年监禁。两名囚犯由于无法交流,必须做出对自己最有利的选择,从理性角度出发,选择招供是个人的最优决策,对方做出任何决定对于招供方都会是一个相对较好的结果,我们称这样的平衡为纳什平衡点。
在生成对抗网络中,我们要计算的纳什平衡点正是要寻找一个生成器G与判别器D,使得各自的代价函数最小,从上面的推导中也可以得出我们希望找到一个V(θ(D),θ(G)),使其对于生成器来说最小而对于判别器来说最大,我们可以把它定义成一个寻找极大极小值的问题,公式如下所示。
我们可以用图形化的方法去理解一下这个极大极小值的概念,一个很好的例子就是鞍点(saddle point),如图3-9所示,即在一个方向是函数的极大值点,而在另一个方向是函数的极小值点。
图3-9 鞍点
在上面公式的基础上,我们可以分别求出理想的判别器D*和生成器G*。
下面我们先来看一下如何求出理想的判别器,对于上述的D*,我们假定生成器G是固定的,令式子中的G(z) = x。推导如下。
我们现在的目标是希望寻找一个D使得V最大,希望积分中的项f(x) = pdata(x) log D(x) + pg(x)log(1 − D(x))无论x取何值都能最大。其中,我们已知pdata是固定的,之前我们也假定生成器G固定,所以pg也是固定的,所以可以很容易地求出D使得f(x)最大。我们假设x固定,f(x)对D(x)求导等于零,下面是求解D(x)的推导。
最终我们求得D*(x)的形式如下所示。
可以看出它是一个范围在0~1的值,这也符合我们判别器的模式,理想的判别器在接收到真实数据时应该判断为1,而对于生成数据则应该判断0。当生成数据分布与真实数据分布非常接近时,应该输出的结果为。
找到了D*之后,我们来推导一下生成器G*。现在先把D*(x)代入前面的积分式子中重新表示。
到了这一步,我们需要先介绍一个定义——Jensen–Shannon散度,这里简称JS散度。在概率统计中,JS散度也和前面提到的KL散度一样具备了测量两个概率分布相似程度的能力,它的计算是基于KL散度的,继承了KL散度的非负性等,但一点重要的区别是,JS散度具备了对称性。JS散度的公式如下,我们还是以P和Q作为例子,另外我们设定,KL为KL散度公式。
如果我们把KL的公式代入展开的话,结果如下。
现在我们回到之前的式子,可以把它转化成JS散度的形式。
对于上面的,由于JS散度是非负的,当且仅当pdata = pg时,上式可以取得全局最小值− log(4)。所以我们要求的最优生成器G*,正是要使得G*的分布pg = pdata。
到此为止,我们已经看到了生成对抗网络在数学理论上是如何成立的,在第4章的开始部分会介绍实际操作中是如何实现上述构想的。