原标题:简单理解 GAN
来源:AI 研习社] 链接:https://www.yanxishe.com/blogDetail/17823
GAN -----Generative Adversarial Nets(生成对抗网络),包含两个部分:生成器/辨别器
简单例子:
训练生成器生成数据,数据要求Mean=4,STD=1.5的高斯分布。
首先数据准备:
生成数据代码如下:
def sample_data(size,length=10):
data=[]
for i in range(size):
data.append(sorted(np.random.normal(4,1.5,length))) # sorted 是为了训练相对平滑
return np.array(data)
生成噪声数据:
def random_data(size,length=10):
data=[]
for i in range(size):
data.append(np.random.random(length))
return np.array(data)
GAN生成结构:
确定生成器和辨别器结构,假设使用最简单的全连接网络结构。判别器用于区分真实的高斯分布和生成器生成的假的高斯分布。
可以视为二分类问题,假设采用交叉熵损失:
分类正确交叉熵趋近于0,通过拟合辨别器的参数来最小化交叉熵。
训练:
将真实数据标为1,生成器生成数据标为0. 迭代训练辨别器。
生成器输出为辨别器的输入,固定辨别器的参数,将输入生成器的噪音样本标记为1,最小化损失函数,训练生成器。
反复迭代,最终生成器的生成能力ok。
一THE END一
免责声明:本文来自互联网新闻客户端自媒体,不代表本网的观点和立场。
合作及投稿邮箱:E-mail:editor@tusaishared.com