文章信息
题目:GAN:Generative Adversarial Nets
原文:https://arxiv.org/pdf/1406.2661.pdf
代码:www.github.com/goodfeli/adversarial
数据集:MNISTCIFAR-10 he Toronto Face Database
一、简述
??GAN,即对抗生成网络,最初由由Ian Goodfellow于2014年提出,GAN网络在图像生成、转化、修复合成等都有应用。最初GAN由两个部分构成——生成器G和判别器D。生成器G将随机输入的高斯噪声映射成一副生成图(假图);判别器D对输入图像进行判断计算该图来自生成器的概率(真图概率)。若生成器生成的假图骗过了判别器,则会被识别为真实图,反之假图。原始GAN整个框架是生成器G与判别器D两者之间的相互博弈的动态过程,该过程形象地表示为“对抗”。作者提到不需要任何马尔可夫链或类似推理,而马尔可夫和扩散模型有关。
二、GAN训练
2.1 GAN目标函数
??符号含义:G表示生成器,D表示判别器;x表示真实数据,Pdata表示真实数据的概率分布;z表示随机输入的高斯噪声,Pz表示生成数据的概率分布;G(z)表示生成图像,D(x)表示对真实图像的判别,D(G(x))表对生成图像的判别。
??生成器G实际上是G(z,theta_g),theta_g是生成器需要学习的参数,判别器D实际上是G(x,theta_d),theta_d是判别器需要学习的参数。若D(x)是非常完美的判别器,D(x)=1,则V(D,G)=0,而若有误分类时,D(x)<1时,log(D(x))和log(1-D(G(z))都会是负值。
??先从判别器D角度来看,判别器D希望尽可能区分真实图像x和生成图像G(z),因此D(x)必须尽可能大,D(G(z))尽可能小,也就是V(D,G)整体尽可能大,即对D取max V(D,G);从生成器G的角度来看,生成器G希望真实图像x和生成图像G(z)区分不开,即希望虚假数据G(z)可以尽可能骗过判别器D,也就是希望D(G(z))尽可能大,1-D(G(z)就更小,log(1-D(G(z))更小,也就是V(D,G)整体尽可能小,即对G取min V(D,G)。因此,这里要训练2个模型,即D和G,当D和G都不再变动时,达到一个平衡。
2.2 GAN训练示意图
??线条含义:直线z表示输入的随机噪声(一维标量),直线x表示真实图像数据(一维标量)对应黑线表示真实数据分布Pdata,z—>x表示生成模型把z映射成为x得到绿线表示生成器G的虚假数据的概率分布Pg。蓝线表示判别器D的输出,
??图(a)-?表示GAN在前3步所做的事情,图(d)表示GAN在最后1步所做的事情。
??图(a)-?:真实分布和生成不同,判别器可以区分真实和生成的图像。图(d);生成器G的概率密度分布慢慢的逼近真实数据集的概率密度分布,而判别器预测值也在不断下降,当出现下图(d)的情况时,(G(z))=0.5,即分不清输入图像到底是真实图像还是生成器伪造的假图。
2.3 GAN算法
??内循环for:要迭代k步。批量化采样m个噪音样本{z1,…,zm}和m个来自真实数据的样本{x1,…,xm},组成2m个大小的小批量,放入价值函数V(D,G)中求梯度(真实样本放入判别器D(xi)),噪音样本放入生成器G(zi)),对判别器参数求梯度来更新判别器。
??外循环for:采样m个噪音样本{z1,…,zm},放入价值函数V(D,G)的第二项里(噪音样本放入生成器G(z^i)),对生成器参数求梯度来更生成器(生成器和第一项无关,计算第二项即可)
??k是超参数,这里作者取k=1。先训练判别器,再训练生成器。个人认为之所以要先训练D是因为如果先更新了生成器,若生成器更新的不好,接着更新判别器,效果也会变得不好。若生成器更新的太好,1-logG(z^i))会变成0,接着更新判别器,时第二项为0,没有意义。