PyTorch Advanced (1): Generative Adversarial Networks
Reference code
yunjey's pytorch-tutorial series
Learning resources for generative adversarial networks
I find generative adversarial networks quite interesting: they can generate images, and the generated images can in turn be used as samples.
Paper (arXiv)
Generative Adversarial Networks
Recommended blogs and tutorials
What Exactly Is a Generative Adversarial Network (GAN)? (Microsoft Research Asia)
Mofan (莫烦) video: What is a GAN?
Mofan (莫烦) video: Implementing a GAN in PyTorch
PyTorch Implementation
Preprocessing
```python
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
```
```python
# Use GPU 1 if available; adjust or drop set_device on a single-GPU machine
torch.cuda.set_device(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```
```python
# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784      # 28 x 28 MNIST images, flattened
num_epochs = 200
batch_size = 100
sample_dir = 'samples'
```
```python
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)
```
```python
# Scale images to [-1, 1] to match the Tanh output range of the generator.
# Note: MNIST images have a single channel; recent torchvision versions may
# require mean=(0.5,), std=(0.5,) here instead of three values.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
```
MNIST dataset
```python
mnist = torchvision.datasets.MNIST(root='../../../data/minist',
                                   train=True,
                                   transform=transform,
                                   download=True)
```
```python
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)
```
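To see where `image_size = 784` comes from, it helps to peek at one batch from the loader. A small sanity-check sketch (the printed shapes assume the settings above):

```python
# Inspect one batch: each MNIST image is 1 x 28 x 28 and gets flattened to 784
images, labels = next(iter(data_loader))
print(images.shape)                           # torch.Size([100, 1, 28, 28])
print(images.reshape(batch_size, -1).shape)   # torch.Size([100, 784])
```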
Building the discriminator and the generator
```python
# Discriminator: 784-d image -> probability that the input is real
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size),
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())
```
```python
# Generator: 64-d latent noise -> 784-d image; Tanh keeps outputs in [-1, 1]
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())
```
```python
D = D.to(device)
G = G.to(device)
```
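Before training, a quick shape check can confirm the two networks fit together (an illustrative sketch, not part of the tutorial code): G maps a 64-d noise vector to a 784-d image in (-1, 1), and D maps a flattened image to a single probability in (0, 1).

```python
# Illustrative shape check for the generator/discriminator pair
z = torch.randn(batch_size, latent_size).to(device)  # random latent noise
fake = G(z)                                           # [100, 784], values in (-1, 1) from Tanh
prob = D(fake)                                        # [100, 1], values in (0, 1) from Sigmoid
print(fake.shape, prob.shape)
```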
```python
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)
```
```python
def denorm(x):
    # Map values from [-1, 1] back to [0, 1] for saving as an image
    out = (x + 1) / 2
    return out.clamp(0, 1)

def reset_grad():
    # Clear the gradients of both optimizers before each backward pass
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()
```
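`denorm` undoes the earlier `Normalize` step: both the normalized real images and the Tanh outputs of G live in [-1, 1], while `save_image` expects pixel values in [0, 1]. A tiny sketch of the mapping:

```python
# denorm maps [-1, 1] back to [0, 1] and clamps anything outside that range
x = torch.tensor([-1.0, 0.0, 1.0, 1.5])
print(denorm(x))   # tensor([0.0000, 0.5000, 1.0000, 1.0000])
```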
Adversarial training
Each iteration alternates between two steps:
- Fix the generator and optimize the discriminator
- Fix the discriminator and optimize the generator
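These two steps come from the minimax objective of the original GAN paper linked above:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]
$$

In the code below, following the paper's practical recommendation, G is trained to maximize log D(G(z)) rather than minimize log(1 - D(G(z))); with `BCELoss` this is simply `criterion(D(G(z)), real_labels)`.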
```python
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Flatten each 1 x 28 x 28 image into a 784-d vector
        images = images.reshape(batch_size, -1).to(device)

        # Labels used by the BCE loss: 1 for real, 0 for fake
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ---------- Step 1: train the discriminator ----------
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---------- Step 2: train the generator ----------
        z = torch.randn(batch_size, latent_size).to(device)
        fake_images = G(z)
        outputs = D(fake_images)
        # Train G so that D classifies its output as real (maximize log D(G(z)))
        g_loss = criterion(outputs, real_labels)

        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i+1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                          real_score.mean().item(), fake_score.mean().item()))

    # Save the real images once, and the generated samples at the end of every epoch
    if (epoch+1) == 1:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))
```
Epoch [0/200], Step [200/600], d_loss: 0.0358, g_loss: 4.3944, D(x): 0.99, D(G(z)): 0.03
Epoch [0/200], Step [400/600], d_loss: 0.1378, g_loss: 5.3265, D(x): 0.96, D(G(z)): 0.06
Epoch [0/200], Step [600/600], d_loss: 0.2501, g_loss: 3.8991, D(x): 0.88, D(G(z)): 0.09
Epoch [1/200], Step [200/600], d_loss: 0.0795, g_loss: 5.9620, D(x): 0.97, D(G(z)): 0.05
Epoch [1/200], Step [400/600], d_loss: 0.0882, g_loss: 5.2908, D(x): 0.95, D(G(z)): 0.01
Epoch [1/200], Step [600/600], d_loss: 0.4405, g_loss: 3.4986, D(x): 0.84, D(G(z)): 0.10
Epoch [2/200], Step [200/600], d_loss: 0.9144, g_loss: 3.6864, D(x): 0.86, D(G(z)): 0.39
Epoch [2/200], Step [400/600], d_loss: 0.3320, g_loss: 3.6846, D(x): 0.89, D(G(z)): 0.09
Epoch [2/200], Step [600/600], d_loss: 0.3677, g_loss: 4.0595, D(x): 0.84, D(G(z)): 0.07
Epoch [3/200], Step [200/600], d_loss: 1.1544, g_loss: 1.7915, D(x): 0.64, D(G(z)): 0.25
Epoch [3/200], Step [400/600], d_loss: 0.2243, g_loss: 5.7187, D(x): 0.93, D(G(z)): 0.03
Epoch [3/200], Step [600/600], d_loss: 0.7260, g_loss: 3.2014, D(x): 0.84, D(G(z)): 0.34
Epoch [4/200], Step [200/600], d_loss: 0.3508, g_loss: 3.5732, D(x): 0.88, D(G(z)): 0.11
Epoch [4/200], Step [400/600], d_loss: 0.2381, g_loss: 3.2004, D(x): 0.87, D(G(z)): 0.05
Epoch [4/200], Step [600/600], d_loss: 1.0712, g_loss: 2.0204, D(x): 0.75, D(G(z)): 0.32
Epoch [5/200], Step [200/600], d_loss: 0.8128, g_loss: 1.9196, D(x): 0.75, D(G(z)): 0.19
Epoch [5/200], Step [400/600], d_loss: 0.6710, g_loss: 3.2527, D(x): 0.88, D(G(z)): 0.29
Epoch [5/200], Step [600/600], d_loss: 0.1591, g_loss: 3.2350, D(x): 0.94, D(G(z)): 0.07
Epoch [6/200], Step [200/600], d_loss: 0.1656, g_loss: 2.8785, D(x): 0.96, D(G(z)): 0.09
Epoch [6/200], Step [400/600], d_loss: 0.1527, g_loss: 5.3023, D(x): 0.94, D(G(z)): 0.04
Epoch [6/200], Step [600/600], d_loss: 0.1915, g_loss: 3.6295, D(x): 0.96, D(G(z)): 0.09
Epoch [7/200], Step [200/600], d_loss: 0.2680, g_loss: 3.8414, D(x): 0.96, D(G(z)): 0.16
Epoch [7/200], Step [400/600], d_loss: 0.3089, g_loss: 3.1775, D(x): 0.91, D(G(z)): 0.06
Epoch [7/200], Step [600/600], d_loss: 0.1621, g_loss: 3.6864, D(x): 0.96, D(G(z)): 0.09
Epoch [8/200], Step [200/600], d_loss: 0.1415, g_loss: 5.7511, D(x): 0.96, D(G(z)): 0.07
Epoch [8/200], Step [400/600], d_loss: 0.2798, g_loss: 4.0020, D(x): 0.93, D(G(z)): 0.12
Epoch [8/200], Step [600/600], d_loss: 0.1186, g_loss: 6.0730, D(x): 0.95, D(G(z)): 0.01
Epoch [9/200], Step [200/600], d_loss: 0.1194, g_loss: 3.9334, D(x): 0.95, D(G(z)): 0.04
Epoch [9/200], Step [400/600], d_loss: 0.2538, g_loss: 4.9895, D(x): 0.95, D(G(z)): 0.09
Epoch [9/200], Step [600/600], d_loss: 0.0858, g_loss: 5.8676, D(x): 0.99, D(G(z)): 0.06
Epoch [10/200], Step [200/600], d_loss: 0.1646, g_loss: 5.9101, D(x): 0.96, D(G(z)): 0.08
Epoch [10/200], Step [400/600], d_loss: 0.1347, g_loss: 6.3711, D(x): 0.94, D(G(z)): 0.01
Epoch [10/200], Step [600/600], d_loss: 0.1253, g_loss: 7.6121, D(x): 0.93, D(G(z)): 0.01
........................
Epoch [100/200], Step [200/600], d_loss: 0.5259, g_loss: 1.8732, D(x): 0.89, D(G(z)): 0.27
Epoch [100/200], Step [400/600], d_loss: 0.9937, g_loss: 1.6773, D(x): 0.65, D(G(z)): 0.24
Epoch [100/200], Step [600/600], d_loss: 0.8688, g_loss: 1.7864, D(x): 0.65, D(G(z)): 0.19
Epoch [101/200], Step [200/600], d_loss: 0.7689, g_loss: 1.8882, D(x): 0.73, D(G(z)): 0.23
Epoch [101/200], Step [400/600], d_loss: 0.8216, g_loss: 1.7205, D(x): 0.81, D(G(z)): 0.36
Epoch [101/200], Step [600/600], d_loss: 0.9299, g_loss: 1.3693, D(x): 0.73, D(G(z)): 0.31
Epoch [102/200], Step [200/600], d_loss: 0.8539, g_loss: 2.1614, D(x): 0.75, D(G(z)): 0.29
Epoch [102/200], Step [400/600], d_loss: 1.1440, g_loss: 1.2271, D(x): 0.80, D(G(z)): 0.47
Epoch [102/200], Step [600/600], d_loss: 0.8884, g_loss: 1.9186, D(x): 0.71, D(G(z)): 0.26
Epoch [103/200], Step [200/600], d_loss: 0.8597, g_loss: 1.6677, D(x): 0.76, D(G(z)): 0.29
Epoch [103/200], Step [400/600], d_loss: 0.7532, g_loss: 1.6774, D(x): 0.84, D(G(z)): 0.32
Epoch [103/200], Step [600/600], d_loss: 0.7142, g_loss: 2.0797, D(x): 0.71, D(G(z)): 0.17
Epoch [104/200], Step [200/600], d_loss: 0.8159, g_loss: 1.5567, D(x): 0.77, D(G(z)): 0.31
Epoch [104/200], Step [400/600], d_loss: 0.6156, g_loss: 2.4817, D(x): 0.82, D(G(z)): 0.25
Epoch [104/200], Step [600/600], d_loss: 0.9245, g_loss: 1.7658, D(x): 0.70, D(G(z)): 0.28
Epoch [105/200], Step [200/600], d_loss: 0.7685, g_loss: 2.1336, D(x): 0.75, D(G(z)): 0.29
Epoch [105/200], Step [400/600], d_loss: 0.7089, g_loss: 2.3518, D(x): 0.77, D(G(z)): 0.24
Epoch [105/200], Step [600/600], d_loss: 0.9698, g_loss: 1.7970, D(x): 0.67, D(G(z)): 0.27
Epoch [106/200], Step [200/600], d_loss: 0.8346, g_loss: 1.7312, D(x): 0.74, D(G(z)): 0.28
Epoch [106/200], Step [400/600], d_loss: 0.8106, g_loss: 1.6253, D(x): 0.84, D(G(z)): 0.37
Epoch [106/200], Step [600/600], d_loss: 0.7766, g_loss: 2.2826, D(x): 0.76, D(G(z)): 0.27
Epoch [107/200], Step [200/600], d_loss: 0.8676, g_loss: 2.3836, D(x): 0.69, D(G(z)): 0.25
Epoch [107/200], Step [400/600], d_loss: 0.6806, g_loss: 1.9751, D(x): 0.79, D(G(z)): 0.25
Epoch [107/200], Step [600/600], d_loss: 0.6495, g_loss: 2.2493, D(x): 0.78, D(G(z)): 0.23
Epoch [108/200], Step [200/600], d_loss: 0.6427, g_loss: 2.2878, D(x): 0.74, D(G(z)): 0.18
Epoch [108/200], Step [400/600], d_loss: 1.0293, g_loss: 1.6300, D(x): 0.70, D(G(z)): 0.33
Epoch [108/200], Step [600/600], d_loss: 0.9008, g_loss: 1.7924, D(x): 0.71, D(G(z)): 0.29
Epoch [109/200], Step [200/600], d_loss: 0.7724, g_loss: 1.9093, D(x): 0.77, D(G(z)): 0.27
Epoch [109/200], Step [400/600], d_loss: 0.9510, g_loss: 1.7266, D(x): 0.66, D(G(z)): 0.23
Epoch [109/200], Step [600/600], d_loss: 0.9278, g_loss: 2.0237, D(x): 0.69, D(G(z)): 0.27
........................
Epoch [190/200], Step [200/600], d_loss: 1.0075, g_loss: 1.3284, D(x): 0.75, D(G(z)): 0.41
Epoch [190/200], Step [400/600], d_loss: 0.9317, g_loss: 1.5531, D(x): 0.68, D(G(z)): 0.31
Epoch [190/200], Step [600/600], d_loss: 0.9142, g_loss: 2.0236, D(x): 0.71, D(G(z)): 0.33
Epoch [191/200], Step [200/600], d_loss: 0.9734, g_loss: 1.7790, D(x): 0.65, D(G(z)): 0.28
Epoch [191/200], Step [400/600], d_loss: 1.2256, g_loss: 1.2511, D(x): 0.73, D(G(z)): 0.48
Epoch [191/200], Step [600/600], d_loss: 0.9106, g_loss: 1.3276, D(x): 0.72, D(G(z)): 0.33
Epoch [192/200], Step [200/600], d_loss: 1.0399, g_loss: 1.3145, D(x): 0.64, D(G(z)): 0.32
Epoch [192/200], Step [400/600], d_loss: 0.8642, g_loss: 1.4594, D(x): 0.72, D(G(z)): 0.32
Epoch [192/200], Step [600/600], d_loss: 0.8914, g_loss: 1.5153, D(x): 0.76, D(G(z)): 0.35
Epoch [193/200], Step [200/600], d_loss: 0.8702, g_loss: 1.6950, D(x): 0.64, D(G(z)): 0.22
Epoch [193/200], Step [400/600], d_loss: 0.9578, g_loss: 1.5045, D(x): 0.67, D(G(z)): 0.31
Epoch [193/200], Step [600/600], d_loss: 1.1071, g_loss: 1.1405, D(x): 0.64, D(G(z)): 0.37
Epoch [194/200], Step [200/600], d_loss: 0.9798, g_loss: 1.4018, D(x): 0.65, D(G(z)): 0.31
Epoch [194/200], Step [400/600], d_loss: 0.8613, g_loss: 1.5124, D(x): 0.73, D(G(z)): 0.32
Epoch [194/200], Step [600/600], d_loss: 1.0008, g_loss: 1.8627, D(x): 0.66, D(G(z)): 0.28
Epoch [195/200], Step [200/600], d_loss: 0.9719, g_loss: 1.7610, D(x): 0.60, D(G(z)): 0.22
Epoch [195/200], Step [400/600], d_loss: 0.9135, g_loss: 1.3339, D(x): 0.78, D(G(z)): 0.38
Epoch [195/200], Step [600/600], d_loss: 0.9025, g_loss: 1.4633, D(x): 0.71, D(G(z)): 0.31
Epoch [196/200], Step [200/600], d_loss: 0.7495, g_loss: 1.9101, D(x): 0.77, D(G(z)): 0.27
Epoch [196/200], Step [400/600], d_loss: 1.0649, g_loss: 1.3968, D(x): 0.70, D(G(z)): 0.40
Epoch [196/200], Step [600/600], d_loss: 0.9115, g_loss: 1.4607, D(x): 0.65, D(G(z)): 0.28
Epoch [197/200], Step [200/600], d_loss: 0.9223, g_loss: 1.5513, D(x): 0.72, D(G(z)): 0.34
Epoch [197/200], Step [400/600], d_loss: 1.1988, g_loss: 1.2663, D(x): 0.62, D(G(z)): 0.36
Epoch [197/200], Step [600/600], d_loss: 0.9683, g_loss: 1.2183, D(x): 0.74, D(G(z)): 0.39
Epoch [198/200], Step [200/600], d_loss: 0.8696, g_loss: 1.6043, D(x): 0.68, D(G(z)): 0.28
Epoch [198/200], Step [400/600], d_loss: 1.1423, g_loss: 1.5558, D(x): 0.68, D(G(z)): 0.39
Epoch [198/200], Step [600/600], d_loss: 1.0431, g_loss: 1.5314, D(x): 0.68, D(G(z)): 0.39
Epoch [199/200], Step [200/600], d_loss: 1.0278, g_loss: 1.6758, D(x): 0.59, D(G(z)): 0.25
Epoch [199/200], Step [400/600], d_loss: 0.9197, g_loss: 1.6217, D(x): 0.66, D(G(z)): 0.30
Epoch [199/200], Step [600/600], d_loss: 0.8748, g_loss: 1.5809, D(x): 0.71, D(G(z)): 0.31
Results
```python
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
```
Real images
```python
realPath = './samples/real_images.png'
realImage = mpimg.imread(realPath)
plt.imshow(realImage)
plt.axis('off')
plt.show()
```
Evolution of the fake images
The figures below show the samples generated after epochs 1, 5, 195, and 200.
```python
fakePath1 = './samples/fake_images-1.png'
fakeImg1 = mpimg.imread(fakePath1)

fakePath5 = './samples/fake_images-5.png'
fakeImg5 = mpimg.imread(fakePath5)

plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(fakeImg1)
plt.subplot(1, 2, 2)
plt.imshow(fakeImg5)
plt.axis('off')
plt.show()

fakePath195 = './samples/fake_images-195.png'
fakeImg195 = mpimg.imread(fakePath195)

fakePath200 = './samples/fake_images-200.png'
fakeImg200 = mpimg.imread(fakePath200)

plt.figure()
plt.subplot(1, 2, 1)
plt.imshow(fakeImg195)
plt.subplot(1, 2, 2)
plt.imshow(fakeImg200)
plt.axis('off')
plt.show()
```
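As noted at the beginning, the point of training G is that it can now produce new "sample" images on demand. A minimal sketch of sampling from the trained generator (the grid size and output file name here are my own choices, not part of the tutorial):

```python
# Draw a fresh batch of digits from the trained generator
with torch.no_grad():
    z = torch.randn(64, latent_size).to(device)    # 64 random latent vectors
    samples = G(z).reshape(-1, 1, 28, 28)          # back to image shape
    save_image(denorm(samples), os.path.join(sample_dir, 'new_samples.png'), nrow=8)
```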