PyTorch

PyTorch 高级篇(1):生成对抗网络(Generative Adversarial Networks)

参考代码

yunjey的 pytorch tutorial系列

生成对抗网络 学习资料

对于生成对抗网络还是很感兴趣的,可以用来生成图像, 生成的图像可以拿来当样本,hiahia。

论文 Arxiv地址

Generative Adversarial Networks

推荐 相关博客和教程

到底什么是生成式对抗网络GAN?By 微软亚洲研究院

莫烦视频:什么是生成对抗网络 (GAN)

莫烦视频:Pytorch实现GAN

PyTorch 实现

预处理阶段

1
2
3
4
5
6
7
# 包
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
1
2
3
# 设备配置
torch.cuda.set_device(1) # 这句用来设置pytorch在哪块GPU上运行
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1
2
3
4
5
6
7
8
# 超参数设置
# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'
1
2
3
# 如果没有文件夹就创建一个文件夹
if not os.path.exists(sample_dir):
os.makedirs(sample_dir)
1
2
3
4
5
6
# 图像处理模块:transform设置
# Image processing:归一化
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels
std=(0.5, 0.5, 0.5))])

MINIST 数据集

1
2
3
4
5
# 加载同时做transform预处理
mnist = torchvision.datasets.MNIST(root='../../../data/minist',
train=True,
transform=transform,
download=True)
1
2
3
4
# 数据加载器:GAN中只考虑判别模型和生成模型的对抗提高,无需设置训练集和测试集
data_loader = torch.utils.data.DataLoader(dataset=mnist,
batch_size=batch_size,
shuffle=True)

判别模型和生成模型的创建

1
2
3
4
5
6
7
8
9
# 创建判别模型
# Discriminator
D = nn.Sequential(
nn.Linear(image_size, hidden_size), # 判别的输入时图像数据
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, hidden_size),
nn.LeakyReLU(0.2),
nn.Linear(hidden_size, 1),
nn.Sigmoid())
1
2
3
4
5
6
7
8
9
# 创建生成模型
# Generator
G = nn.Sequential(
nn.Linear(latent_size, hidden_size), # 生成的输入是随机数,可以自己定义
nn.ReLU(),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, image_size),
nn.Tanh())
1
2
3
4
# 拷到计算设备上
# Device setting
D = D.to(device)
G = G.to(device)
1
2
3
4
# 设置损失函数和优化器
criterion = nn.BCELoss() # 二值交叉熵 Binary cross entropy loss
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
1
2
3
4
5
6
7
8
9
10
# 定义两个函数

def denorm(x):
out = (x + 1) / 2
return out.clamp(0, 1)

# 重置梯度
def reset_grad():
d_optimizer.zero_grad()
g_optimizer.zero_grad()

对抗生成训练

分两步:

  1. 固定生成模型,优化判别模型
  2. 固定判别模型,优化生成模型
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
total_step = len(data_loader)
for epoch in range(num_epochs):
for i, (images, _) in enumerate(data_loader):
images = images.reshape(batch_size, -1).to(device)

# 创建标签,随后会用于损失函数BCE loss的计算
real_labels = torch.ones(batch_size, 1).to(device) # true_label设为1,表示True
fake_labels = torch.zeros(batch_size, 1).to(device) # fake_label设为0,表示False

# ================================================================== #
# 训练判别模型 #
# ================================================================== #

# 计算real_损失
# 使用公式 BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x)),来计算realimage的判别损失
# 其中第二项永远为零,因为real_labels == 1
outputs = D(images)
d_loss_real = criterion(outputs, real_labels)
real_score = outputs


# 计算fake损失
# 生成模型根据随机输入生成fake_images
z = torch.randn(batch_size, latent_size).to(device)
fake_images = G(z)
# 使用公式 BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x)),来计算fakeImage的判别损失
# 其中第一项永远为零,因为fake_labels == 0
outputs = D(fake_images)
d_loss_fake = criterion(outputs, fake_labels)
fake_score = outputs

# 反向传播和优化
d_loss = d_loss_real + d_loss_fake
reset_grad()
d_loss.backward()
d_optimizer.step()

# ================================================================== #
# 训练生成模型 #
# ================================================================== #

# 生成模型根据随机输入生成fake_images,然后判别模型进行判别
z = torch.randn(batch_size, latent_size).to(device)
fake_images = G(z)
outputs = D(fake_images)

# 训练生成模型,使之最大化 log(D(G(z)) ,而不是最小化 log(1-D(G(z)))
# 具体的解释在原文第三小节最后一段有解释
# 大致含义就是在训练初期,生成模型G还很菜,判别模型会拒绝高置信度的样本,因为这些样本与训练数据不同。
# 这样log(1-D(G(z)))就近乎饱和,梯度计算得到的值很小,不利于反向传播和训练。
# 换一种思路,通过计算最大化log(D(G(z)),就能够在训练初期提供较大的梯度值,利于快速收敛
g_loss = criterion(outputs, real_labels)

# 反向传播和优化
reset_grad()
g_loss.backward()
g_optimizer.step()

if (i+1) % 200 == 0:
print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
real_score.mean().item(), fake_score.mean().item()))

# 在第一轮保存训练数据图像
if (epoch+1) == 1:
images = images.reshape(images.size(0), 1, 28, 28)
save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

# 每一轮保存 生成的样本(即fake_images)
fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
Epoch [0/200], Step [200/600], d_loss: 0.0358, g_loss: 4.3944, D(x): 0.99, D(G(z)): 0.03
Epoch [0/200], Step [400/600], d_loss: 0.1378, g_loss: 5.3265, D(x): 0.96, D(G(z)): 0.06
Epoch [0/200], Step [600/600], d_loss: 0.2501, g_loss: 3.8991, D(x): 0.88, D(G(z)): 0.09
Epoch [1/200], Step [200/600], d_loss: 0.0795, g_loss: 5.9620, D(x): 0.97, D(G(z)): 0.05
Epoch [1/200], Step [400/600], d_loss: 0.0882, g_loss: 5.2908, D(x): 0.95, D(G(z)): 0.01
Epoch [1/200], Step [600/600], d_loss: 0.4405, g_loss: 3.4986, D(x): 0.84, D(G(z)): 0.10
Epoch [2/200], Step [200/600], d_loss: 0.9144, g_loss: 3.6864, D(x): 0.86, D(G(z)): 0.39
Epoch [2/200], Step [400/600], d_loss: 0.3320, g_loss: 3.6846, D(x): 0.89, D(G(z)): 0.09
Epoch [2/200], Step [600/600], d_loss: 0.3677, g_loss: 4.0595, D(x): 0.84, D(G(z)): 0.07
Epoch [3/200], Step [200/600], d_loss: 1.1544, g_loss: 1.7915, D(x): 0.64, D(G(z)): 0.25
Epoch [3/200], Step [400/600], d_loss: 0.2243, g_loss: 5.7187, D(x): 0.93, D(G(z)): 0.03
Epoch [3/200], Step [600/600], d_loss: 0.7260, g_loss: 3.2014, D(x): 0.84, D(G(z)): 0.34
Epoch [4/200], Step [200/600], d_loss: 0.3508, g_loss: 3.5732, D(x): 0.88, D(G(z)): 0.11
Epoch [4/200], Step [400/600], d_loss: 0.2381, g_loss: 3.2004, D(x): 0.87, D(G(z)): 0.05
Epoch [4/200], Step [600/600], d_loss: 1.0712, g_loss: 2.0204, D(x): 0.75, D(G(z)): 0.32
Epoch [5/200], Step [200/600], d_loss: 0.8128, g_loss: 1.9196, D(x): 0.75, D(G(z)): 0.19
Epoch [5/200], Step [400/600], d_loss: 0.6710, g_loss: 3.2527, D(x): 0.88, D(G(z)): 0.29
Epoch [5/200], Step [600/600], d_loss: 0.1591, g_loss: 3.2350, D(x): 0.94, D(G(z)): 0.07
Epoch [6/200], Step [200/600], d_loss: 0.1656, g_loss: 2.8785, D(x): 0.96, D(G(z)): 0.09
Epoch [6/200], Step [400/600], d_loss: 0.1527, g_loss: 5.3023, D(x): 0.94, D(G(z)): 0.04
Epoch [6/200], Step [600/600], d_loss: 0.1915, g_loss: 3.6295, D(x): 0.96, D(G(z)): 0.09
Epoch [7/200], Step [200/600], d_loss: 0.2680, g_loss: 3.8414, D(x): 0.96, D(G(z)): 0.16
Epoch [7/200], Step [400/600], d_loss: 0.3089, g_loss: 3.1775, D(x): 0.91, D(G(z)): 0.06
Epoch [7/200], Step [600/600], d_loss: 0.1621, g_loss: 3.6864, D(x): 0.96, D(G(z)): 0.09
Epoch [8/200], Step [200/600], d_loss: 0.1415, g_loss: 5.7511, D(x): 0.96, D(G(z)): 0.07
Epoch [8/200], Step [400/600], d_loss: 0.2798, g_loss: 4.0020, D(x): 0.93, D(G(z)): 0.12
Epoch [8/200], Step [600/600], d_loss: 0.1186, g_loss: 6.0730, D(x): 0.95, D(G(z)): 0.01
Epoch [9/200], Step [200/600], d_loss: 0.1194, g_loss: 3.9334, D(x): 0.95, D(G(z)): 0.04
Epoch [9/200], Step [400/600], d_loss: 0.2538, g_loss: 4.9895, D(x): 0.95, D(G(z)): 0.09
Epoch [9/200], Step [600/600], d_loss: 0.0858, g_loss: 5.8676, D(x): 0.99, D(G(z)): 0.06
Epoch [10/200], Step [200/600], d_loss: 0.1646, g_loss: 5.9101, D(x): 0.96, D(G(z)): 0.08
Epoch [10/200], Step [400/600], d_loss: 0.1347, g_loss: 6.3711, D(x): 0.94, D(G(z)): 0.01
Epoch [10/200], Step [600/600], d_loss: 0.1253, g_loss: 7.6121, D(x): 0.93, D(G(z)): 0.01

........................

Epoch [100/200], Step [200/600], d_loss: 0.5259, g_loss: 1.8732, D(x): 0.89, D(G(z)): 0.27
Epoch [100/200], Step [400/600], d_loss: 0.9937, g_loss: 1.6773, D(x): 0.65, D(G(z)): 0.24
Epoch [100/200], Step [600/600], d_loss: 0.8688, g_loss: 1.7864, D(x): 0.65, D(G(z)): 0.19
Epoch [101/200], Step [200/600], d_loss: 0.7689, g_loss: 1.8882, D(x): 0.73, D(G(z)): 0.23
Epoch [101/200], Step [400/600], d_loss: 0.8216, g_loss: 1.7205, D(x): 0.81, D(G(z)): 0.36
Epoch [101/200], Step [600/600], d_loss: 0.9299, g_loss: 1.3693, D(x): 0.73, D(G(z)): 0.31
Epoch [102/200], Step [200/600], d_loss: 0.8539, g_loss: 2.1614, D(x): 0.75, D(G(z)): 0.29
Epoch [102/200], Step [400/600], d_loss: 1.1440, g_loss: 1.2271, D(x): 0.80, D(G(z)): 0.47
Epoch [102/200], Step [600/600], d_loss: 0.8884, g_loss: 1.9186, D(x): 0.71, D(G(z)): 0.26
Epoch [103/200], Step [200/600], d_loss: 0.8597, g_loss: 1.6677, D(x): 0.76, D(G(z)): 0.29
Epoch [103/200], Step [400/600], d_loss: 0.7532, g_loss: 1.6774, D(x): 0.84, D(G(z)): 0.32
Epoch [103/200], Step [600/600], d_loss: 0.7142, g_loss: 2.0797, D(x): 0.71, D(G(z)): 0.17
Epoch [104/200], Step [200/600], d_loss: 0.8159, g_loss: 1.5567, D(x): 0.77, D(G(z)): 0.31
Epoch [104/200], Step [400/600], d_loss: 0.6156, g_loss: 2.4817, D(x): 0.82, D(G(z)): 0.25
Epoch [104/200], Step [600/600], d_loss: 0.9245, g_loss: 1.7658, D(x): 0.70, D(G(z)): 0.28
Epoch [105/200], Step [200/600], d_loss: 0.7685, g_loss: 2.1336, D(x): 0.75, D(G(z)): 0.29
Epoch [105/200], Step [400/600], d_loss: 0.7089, g_loss: 2.3518, D(x): 0.77, D(G(z)): 0.24
Epoch [105/200], Step [600/600], d_loss: 0.9698, g_loss: 1.7970, D(x): 0.67, D(G(z)): 0.27
Epoch [106/200], Step [200/600], d_loss: 0.8346, g_loss: 1.7312, D(x): 0.74, D(G(z)): 0.28
Epoch [106/200], Step [400/600], d_loss: 0.8106, g_loss: 1.6253, D(x): 0.84, D(G(z)): 0.37
Epoch [106/200], Step [600/600], d_loss: 0.7766, g_loss: 2.2826, D(x): 0.76, D(G(z)): 0.27
Epoch [107/200], Step [200/600], d_loss: 0.8676, g_loss: 2.3836, D(x): 0.69, D(G(z)): 0.25
Epoch [107/200], Step [400/600], d_loss: 0.6806, g_loss: 1.9751, D(x): 0.79, D(G(z)): 0.25
Epoch [107/200], Step [600/600], d_loss: 0.6495, g_loss: 2.2493, D(x): 0.78, D(G(z)): 0.23
Epoch [108/200], Step [200/600], d_loss: 0.6427, g_loss: 2.2878, D(x): 0.74, D(G(z)): 0.18
Epoch [108/200], Step [400/600], d_loss: 1.0293, g_loss: 1.6300, D(x): 0.70, D(G(z)): 0.33
Epoch [108/200], Step [600/600], d_loss: 0.9008, g_loss: 1.7924, D(x): 0.71, D(G(z)): 0.29
Epoch [109/200], Step [200/600], d_loss: 0.7724, g_loss: 1.9093, D(x): 0.77, D(G(z)): 0.27
Epoch [109/200], Step [400/600], d_loss: 0.9510, g_loss: 1.7266, D(x): 0.66, D(G(z)): 0.23
Epoch [109/200], Step [600/600], d_loss: 0.9278, g_loss: 2.0237, D(x): 0.69, D(G(z)): 0.27

........................

Epoch [190/200], Step [200/600], d_loss: 1.0075, g_loss: 1.3284, D(x): 0.75, D(G(z)): 0.41
Epoch [190/200], Step [400/600], d_loss: 0.9317, g_loss: 1.5531, D(x): 0.68, D(G(z)): 0.31
Epoch [190/200], Step [600/600], d_loss: 0.9142, g_loss: 2.0236, D(x): 0.71, D(G(z)): 0.33
Epoch [191/200], Step [200/600], d_loss: 0.9734, g_loss: 1.7790, D(x): 0.65, D(G(z)): 0.28
Epoch [191/200], Step [400/600], d_loss: 1.2256, g_loss: 1.2511, D(x): 0.73, D(G(z)): 0.48
Epoch [191/200], Step [600/600], d_loss: 0.9106, g_loss: 1.3276, D(x): 0.72, D(G(z)): 0.33
Epoch [192/200], Step [200/600], d_loss: 1.0399, g_loss: 1.3145, D(x): 0.64, D(G(z)): 0.32
Epoch [192/200], Step [400/600], d_loss: 0.8642, g_loss: 1.4594, D(x): 0.72, D(G(z)): 0.32
Epoch [192/200], Step [600/600], d_loss: 0.8914, g_loss: 1.5153, D(x): 0.76, D(G(z)): 0.35
Epoch [193/200], Step [200/600], d_loss: 0.8702, g_loss: 1.6950, D(x): 0.64, D(G(z)): 0.22
Epoch [193/200], Step [400/600], d_loss: 0.9578, g_loss: 1.5045, D(x): 0.67, D(G(z)): 0.31
Epoch [193/200], Step [600/600], d_loss: 1.1071, g_loss: 1.1405, D(x): 0.64, D(G(z)): 0.37
Epoch [194/200], Step [200/600], d_loss: 0.9798, g_loss: 1.4018, D(x): 0.65, D(G(z)): 0.31
Epoch [194/200], Step [400/600], d_loss: 0.8613, g_loss: 1.5124, D(x): 0.73, D(G(z)): 0.32
Epoch [194/200], Step [600/600], d_loss: 1.0008, g_loss: 1.8627, D(x): 0.66, D(G(z)): 0.28
Epoch [195/200], Step [200/600], d_loss: 0.9719, g_loss: 1.7610, D(x): 0.60, D(G(z)): 0.22
Epoch [195/200], Step [400/600], d_loss: 0.9135, g_loss: 1.3339, D(x): 0.78, D(G(z)): 0.38
Epoch [195/200], Step [600/600], d_loss: 0.9025, g_loss: 1.4633, D(x): 0.71, D(G(z)): 0.31
Epoch [196/200], Step [200/600], d_loss: 0.7495, g_loss: 1.9101, D(x): 0.77, D(G(z)): 0.27
Epoch [196/200], Step [400/600], d_loss: 1.0649, g_loss: 1.3968, D(x): 0.70, D(G(z)): 0.40
Epoch [196/200], Step [600/600], d_loss: 0.9115, g_loss: 1.4607, D(x): 0.65, D(G(z)): 0.28
Epoch [197/200], Step [200/600], d_loss: 0.9223, g_loss: 1.5513, D(x): 0.72, D(G(z)): 0.34
Epoch [197/200], Step [400/600], d_loss: 1.1988, g_loss: 1.2663, D(x): 0.62, D(G(z)): 0.36
Epoch [197/200], Step [600/600], d_loss: 0.9683, g_loss: 1.2183, D(x): 0.74, D(G(z)): 0.39
Epoch [198/200], Step [200/600], d_loss: 0.8696, g_loss: 1.6043, D(x): 0.68, D(G(z)): 0.28
Epoch [198/200], Step [400/600], d_loss: 1.1423, g_loss: 1.5558, D(x): 0.68, D(G(z)): 0.39
Epoch [198/200], Step [600/600], d_loss: 1.0431, g_loss: 1.5314, D(x): 0.68, D(G(z)): 0.39
Epoch [199/200], Step [200/600], d_loss: 1.0278, g_loss: 1.6758, D(x): 0.59, D(G(z)): 0.25
Epoch [199/200], Step [400/600], d_loss: 0.9197, g_loss: 1.6217, D(x): 0.66, D(G(z)): 0.30
Epoch [199/200], Step [600/600], d_loss: 0.8748, g_loss: 1.5809, D(x): 0.71, D(G(z)): 0.31

结果展示

1
2
3
4
#导入包
import matplotlib.pyplot as plt # plt 用于显示图片
import matplotlib.image as mpimg # mpimg 用于读取图片
import numpy as np

real image

1
2
3
4
5
realPath = './samples/real_images.png'
realImage = mpimg.imread(realPath)
plt.imshow(realImage) # 显示图片
plt.axis('off') # 不显示坐标轴
plt.show()

png

fake image 进化过程

下图分别为第1,5,195,200轮训练生成的结果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# 起始阶段
fakePath1 = './samples/fake_images-1.png'
fakeImg1 = mpimg.imread(fakePath1)

fakePath5 = './samples/fake_images-5.png'
fakeImg5 = mpimg.imread(fakePath5)

plt.figure()
plt.subplot(1,2,1 ) # 显示图片
plt.imshow(fakeImg1) # 显示图片
plt.subplot(1,2,2 ) # 显示图片
plt.imshow(fakeImg5) # 显示图片
plt.axis('off') # 不显示坐标轴
plt.show()

fakePath195 = './samples/fake_images-195.png'
fakeImg195 = mpimg.imread(fakePath195)

fakePath200 = './samples/fake_images-200.png'
fakeImg200 = mpimg.imread(fakePath200)

plt.figure()
plt.subplot(1,2,1 ) # 显示图片
plt.imshow(fakeImg195) # 显示图片
plt.subplot(1,2,2 ) # 显示图片
plt.imshow(fakeImg200) # 显示图片
plt.axis('off') # 不显示坐标轴
plt.show()

png

png