资源技术动态用GAN来做图像生成

用GAN来做图像生成

2019-09-29 | |  108 |   0

原标题:用GAN来做图像生成,这是最好的方法       来源:CSDN博客

原文链接:https://blog.csdn.net/c2a2o2/article/details/78535795


在我们之前的文章中,我们学习了如何构造一个简单的 GAN 来生成 MNIST 手写图片。对于图像问题,卷积神经网络相比于简单地全连接的神经网络更具优势,因此,我们这一节我们将继续深入 GAN,通过融合卷积神经网络来对我们的 GAN 进行改进,实现一个深度卷积 GAN。


专栏中的所有代码都在我的 GitHub中,欢迎 star 与 fork。

本次代码在 NELSONZHAO/zhihu/dcgan,里面包含了两个文件:

dcgan_mnist:基于 MNIST 手写数据集构造深度卷积 GAN 模型

dcgan_cifar:基于 CIFAR 数据集构造深度卷积 GAN 模型


本文主要以 MNIST 为例进行介绍,两者在本质上没有差别,只在细微的参数上有所调整。由于穷学生资源有限,没有对模型增加迭代次数,也没有构造更深的模型。并且也没有选取像素很高的图像,高像素非常消耗计算量。本节只是一个抛砖引玉的作用,让大家了解 DCGAN 的结构,如果有资源的小伙伴可以自己去尝试其他更清晰的图片以及更深的结构,相信会取得很不错的结果。


工具

Python3

TensorFlow 1.0

Jupyter notebook


正文

整个正文部分将包括以下部分:

- 数据加载

- 模型输入

- Generator

- Discriminator

- Loss

- Optimizer

- 训练模型

- 可视化


数据加载


数据加载部分采用 TensorFlow 中的 input_data 接口来进行加载。关于加载细节在前面的文章中已经写了很多次啦,相信看过我文章的小伙伴对 MNIST 加载也非常熟悉,这里不再赘述。


模型输入


在 GAN 中,我们的输入包括两部分,一个是真实图片,它将直接输入给 discriminator 来获得一个判别结果;另一个是随机噪声,随机噪声将作为 generator 来生成图片的材料,generator 再将生成图片传递给 discriminator 获得一个判别结果。


01.jpg


上面的函数定义了输入图片与噪声图片两个 tensor。


Generator


生成器接收一个噪声信号,基于该信号生成一个图片输入给判别器。在上一篇专栏文章生成对抗网络(GAN)之 MNIST 数据生成中,我们的生成器是一个全连接层的神经网络,而本节我们将生成器改造为包含卷积结构的网络,使其更加适合处理图片输入。整个生成器结构如下:


02.jpg


我们采用了 transposed convolution 将我们的噪声图片转换为了一个与输入图片具有相同 shape 的生成图像。我们来看一下具体的实现代码:


03.jpg


上面的代码是整个生成器的实现细节,里面包含了一些 trick,我们来一步步地看一下。


首先我们通过一个全连接层将输入的噪声图像转换成了一个 1 x 4*4*512 的结构,再将其 reshape 成一个 [batch_size, 4, 4, 512] 的形状,至此我们其实完成了第一步的转换。接下来我们使用了一个对加速收敛及提高卷积神经网络性能中非常有效的方法——加入 BN(batch normalization),它的思想是归一化当前层输入,使它们的均值为 0 和方差为 1,类似于我们归一化网络输入的方法。它的好处在于可以加速收敛,并且加入 BN 的卷积神经网络受权重初始化影响非常小,具有非常好的稳定性,对于提升卷积性能有很好的效果。关于 batch normalization,我会在后面专栏中进行一个详细的介绍。


完成 BN 后,我们使用 Leaky ReLU 作为激活函数,在上一篇专栏中我们已经提过这个函数,这里不再赘述。最后加入 dropout 正则化。剩下的 transposed convolution 结构层与之类似,只不过在最后一层中,我们不采用 BN,直接采用 tanh 激活函数输出生成的图片。


在上面的 transposed convolution 中,很多小伙伴肯定会对每一层 size 的变化疑惑,在这里来讲一下在 TensorFlow 中如何来计算每一层 feature map 的 size。首先,在卷积神经网络中,假如我们使用一个 k x k 的 filter 对 m x m x d 的图片进行卷积操作,strides 为 s,在 TensorFlow 中,当我们设置 padding='same'时,卷积以后的每一个 feature map 的 height 和 width 为用GAN来做图像生成,这是最好的方法;当设置 padding='valid'时,每一个 feature map 的 height 和 width 为用GAN来做图像生成,这是最好的方法。那么反过来,如果我们想要进行 transposed convolution 操作,比如将 7 x 7 的形状变为 14 x 14,那么此时,我们可以设置 padding='same',strides=2 即可,与 filter 的 size 没有关系;而如果将 4 x 4 变为 7 x 7 的话,当设置 padding='valid'时,即用GAN来做图像生成,这是最好的方法,此时 s=1,k=4 即可实现我们的目标。


上面的代码中我也标注了每一步 shape 的变化。


Discriminator


Discriminator 接收一个图片,输出一个判别结果(概率)。其实 Discriminator 完全可以看做一个包含卷积神经网络的图片二分类器。结构如下:

04.jpg

实现代码如下:

05.jpg


上面代码其实就是一个简单的卷积神经网络图像识别问题,最终返回 logits(用来计算 loss)与 outputs。这里没有加入池化层的原因在于图片本身经过多层卷积以后已经非常小了,并且我们加入了 batch normalization 加速了训练,并不需要通过 max pooling 来进行特征提取加速训练。


Loss Function


06.jpg


Loss 部分分别计算 Generator 的 loss 与 Discriminator 的 loss,和之前一样,我们加入 label smoothing 防止过拟合,增强泛化能力。


Optimizer


GAN 中实际包含了两个神经网络,因此对于这两个神经网络要分开进行优化。代码如下:


07.jpg


这里的 Optimizer 和我们之前不同,由于我们使用了 TensorFlow 中的 batch normalization 函数,这个函数中有很多 trick 要注意。首先我们要知道,batch normalization 在训练阶段与非训练阶段的计算方式是有差别的,这也是为什么我们在使用 batch normalization 过程中需要指定 training 这个参数。上面使用 tf.control_dependencies 是为了保证在训练阶段能够一直更新 moving averages。具体参考 A Gentle Guide to Using Batch Normalization in Tensorflow - Rui Shu。


训练


到此为止,我们就完成了深度卷积 GAN 的构造,接着我们可以对我们的 GAN 来进行训练,并且定义一些辅助函数来可视化迭代的结果。代码太长就不放上来了,可以直接去我的 GitHub 下载。


我这里只设置了 5 轮 epochs,每隔 100 个 batch 打印一次结果,每一行代表同一个 epoch 下的 25 张图:


08.jpg


我们可以看出仅仅经过了少部分的迭代就已经生成非常清晰的手写数字,并且训练速度是非常快的。


09.jpg


上面的图是最后几次迭代的结果。我们可以回顾一下上一篇的一个简单的全连接层的 GAN,收敛速度明显不如深度卷积 GAN。


总结

到此为止,我们学习了一个深度卷积 GAN,并且看到相比于之前简单的 GAN 来说,深度卷积 GAN 的性能更加优秀。当然除了 MNST 数据集以外,小伙伴儿们还可以尝试很多其他图片,比如我们之前用到过的 CIFAR 数据集,我在这里也实现了一个 CIFAR 数据集的图片生成,我只选取了马的图片进行训练:


刚开始训练时:


10.jpg


训练 50 个 epochs:


11.jpg

这里我只设置了 50 次迭代,可以看到最后已经生成了非常明显的马的图像,可见深度卷积 GAN 的优势。

THE END

免责声明:本文来自互联网新闻客户端自媒体,不代表本网的观点和立场。

合作及投稿邮箱:E-mail:editor@tusaishared.com

上一篇:数字图像基础之图像取样和量化(Image Sampling and Quantization)

下一篇:生成对抗网络(GAN)应用于图像分类

用户评价
全部评价

热门资源

  • 应用笔画宽度变换...

    应用背景:是盲人辅助系统,城市环境中的机器导航...

  • GAN之根据文本描述...

    一些比较好玩的任务也就应运而生,比如图像修复、...

  • 端到端语音识别时...

    从上世纪 50 年代诞生到 2012 年引入 DNN 后识别效...

  • 人体姿态估计的过...

    人体姿态估计是计算机视觉中一个很基础的问题。从...

  • 谷歌发布TyDi QA语...

    为了鼓励对多语言问答技术的研究,谷歌发布了 TyDi...