PyTorch 高级篇(2):变分自编码器(Variational Auto-Encoder)
参考代码
yunjey的 pytorch tutorial系列
变分自编码器 学习资料
自编码器主要有以下几个作用:
数据去噪(去噪编码器)
可视化降维
生成数据(与GAN各有千秋)
文献
Tutorial on Variational Autoencoders
讲解视频
【深度学习】变分自编码机 Arxiv Insights出品 双语字幕by皮艾诺小叔(非直译)
讲解文章
花式解释AutoEncoder与VAE
如何使用变分自编码器VAE生成动漫人物形象
PyTorch 实现
预处理
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
# Device configuration.
# Pin the process to GPU index 1 only when a second GPU is actually present.
# Selecting index 1 unconditionally cannot succeed on CPU-only or single-GPU
# machines, which would make the CPU fallback below unreachable.
if torch.cuda.device_count() > 1:
    torch.cuda.set_device(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Directory that receives the image grids saved during training.
sample_dir = 'samples'
# exist_ok=True makes creation idempotent and avoids the race between a
# separate existence check and the mkdir call.
os.makedirs(sample_dir, exist_ok=True)
# Hyper-parameters
image_size = 28 * 28    # length of one flattened 28x28 image (784)
h_dim = 400             # hidden layer width
z_dim = 20              # latent vector length
num_epochs = 15         # number of passes over the data loader
batch_size = 128        # mini-batch size; also the number of images sampled per epoch
learning_rate = 0.001   # step size handed to the optimizer
# MNIST dataset
# Training split, each image converted to a tensor; downloaded into `root`
# when it is not already present there.
dataset = torchvision.datasets.MNIST(
    root='../../../data/minist',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Data loader: shuffled mini-batches of `batch_size` images.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)
# Create the VAE model (Variational Auto-Encoder)
class VAE(nn.Module):
    """Fully connected variational auto-encoder.

    The encoder maps an input vector to the mean and log-variance of a
    diagonal Gaussian over the latent code; the decoder maps a latent code
    back to a vector of per-element values in [0, 1].

    Args:
        image_size: Length of the (flattened) input vector.
        h_dim: Width of the hidden layer used by encoder and decoder.
        z_dim: Length of the latent vector.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super().__init__()
        # Encoder: shared hidden layer followed by two parallel heads.
        self.fc1 = nn.Linear(image_size, h_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)  # head producing mu
        self.fc3 = nn.Linear(h_dim, z_dim)  # head producing log_var
        # Decoder: latent -> hidden -> reconstruction.
        self.fc4 = nn.Linear(z_dim, h_dim)
        self.fc5 = nn.Linear(h_dim, image_size)

    def encode(self, x: torch.Tensor):
        """Return the pair ``(mu, log_var)`` computed from input ``x``."""
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def reparameterize(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + eps * std`` with ``eps`` drawn from N(0, I).

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``log_var``.
        """
        std = torch.exp(log_var / 2)  # log-variance -> standard deviation
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent code ``z`` to a reconstruction with values in [0, 1]."""
        h = F.relu(self.fc4(z))
        # torch.sigmoid replaces the deprecated nn.functional.sigmoid.
        return torch.sigmoid(self.fc5(h))

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            Tuple ``(x_reconst, mu, log_var)``.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var
# Build the model from the module-level hyper-parameters rather than the
# class defaults, so that editing image_size / h_dim / z_dim keeps the model
# in step with the training loop, which uses the same constants.
model = VAE(image_size=image_size, h_dim=h_dim, z_dim=z_dim).to(device)

# Adam optimizer over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Start training
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Flatten every image in the batch to a vector of length image_size.
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # Reconstruction term, summed over the whole batch.
        # reduction='sum' replaces the deprecated size_average=False.
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        # KL divergence between N(mu, exp(log_var)) and the standard normal.
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Backpropagate and update the parameters.
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, len(data_loader),
                          reconst_loss.item(), kl_div.item()))

    # Once per epoch, save image grids; no gradients are needed here.
    with torch.no_grad():
        # Decode latent codes drawn from the standard normal.
        z = torch.randn(batch_size, z_dim).to(device)
        out = model.decode(z).view(-1, 1, 28, 28)
        save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch + 1)))

        # Reconstruct the last mini-batch of the epoch and place each input
        # next to its reconstruction (concatenated along the width axis).
        out, _, _ = model(x)
        x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
        save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch + 1)))
/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py:1006: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
/home/ubuntu/anaconda3/lib/python3.6/site-packages/torch/nn/functional.py:52: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.
warnings.warn(warning.format(ret))
Epoch[1/15], Step [100/469], Reconst Loss: 9898.7285, KL Div: 3231.0195
Epoch[1/15], Step [200/469], Reconst Loss: 9985.5391, KL Div: 3290.1267
Epoch[1/15], Step [300/469], Reconst Loss: 9800.6211, KL Div: 3201.4980
Epoch[1/15], Step [400/469], Reconst Loss: 9444.1016, KL Div: 3259.1062
Epoch[2/15], Step [100/469], Reconst Loss: 9204.6201, KL Div: 3056.4475
Epoch[2/15], Step [200/469], Reconst Loss: 9729.0078, KL Div: 3206.0845
Epoch[2/15], Step [300/469], Reconst Loss: 9609.4307, KL Div: 3220.1729
Epoch[2/15], Step [400/469], Reconst Loss: 9514.4150, KL Div: 3206.0166
Epoch[3/15], Step [100/469], Reconst Loss: 9042.1270, KL Div: 3145.2937
Epoch[3/15], Step [200/469], Reconst Loss: 9773.1826, KL Div: 3235.4180
Epoch[3/15], Step [300/469], Reconst Loss: 9427.7207, KL Div: 3141.4922
Epoch[3/15], Step [400/469], Reconst Loss: 9658.2725, KL Div: 3235.2390
Epoch[4/15], Step [100/469], Reconst Loss: 9596.0439, KL Div: 3177.3101
Epoch[4/15], Step [200/469], Reconst Loss: 9158.8330, KL Div: 3114.7456
Epoch[4/15], Step [300/469], Reconst Loss: 9519.2754, KL Div: 3100.6924
Epoch[4/15], Step [400/469], Reconst Loss: 9318.7393, KL Div: 3098.9333
Epoch[5/15], Step [100/469], Reconst Loss: 9248.7139, KL Div: 3203.3230
Epoch[5/15], Step [200/469], Reconst Loss: 9914.3438, KL Div: 3244.7737
Epoch[5/15], Step [300/469], Reconst Loss: 9575.4922, KL Div: 3210.8545
Epoch[5/15], Step [400/469], Reconst Loss: 9519.7637, KL Div: 3243.2603
................................
Epoch[11/15], Step [400/469], Reconst Loss: 9872.5010, KL Div: 3267.5239
Epoch[12/15], Step [100/469], Reconst Loss: 9508.9053, KL Div: 3069.8406
Epoch[12/15], Step [200/469], Reconst Loss: 9340.8848, KL Div: 3093.4531
Epoch[12/15], Step [300/469], Reconst Loss: 9537.1279, KL Div: 3208.4387
Epoch[12/15], Step [400/469], Reconst Loss: 9205.0615, KL Div: 3125.3406
Epoch[13/15], Step [100/469], Reconst Loss: 9650.2803, KL Div: 3167.0171
Epoch[13/15], Step [200/469], Reconst Loss: 9609.6025, KL Div: 3179.3223
Epoch[13/15], Step [300/469], Reconst Loss: 9498.6650, KL Div: 3309.2681
Epoch[13/15], Step [400/469], Reconst Loss: 9823.6318, KL Div: 3218.4116
Epoch[14/15], Step [100/469], Reconst Loss: 9167.9990, KL Div: 3097.4619
Epoch[14/15], Step [200/469], Reconst Loss: 9712.9277, KL Div: 3222.7612
Epoch[14/15], Step [300/469], Reconst Loss: 9887.4297, KL Div: 3336.3618
Epoch[14/15], Step [400/469], Reconst Loss: 9485.8965, KL Div: 3180.0781
Epoch[15/15], Step [100/469], Reconst Loss: 9628.2295, KL Div: 3244.9995
Epoch[15/15], Step [200/469], Reconst Loss: 9556.5020, KL Div: 3147.9658
Epoch[15/15], Step [300/469], Reconst Loss: 9569.2588, KL Div: 3193.5071
Epoch[15/15], Step [400/469], Reconst Loss: 9334.9570, KL Div: 3074.2688
# Show results
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np


def _show_image(path):
    """Read the image file at *path* and display it with the axes hidden."""
    plt.imshow(mpimg.imread(path))
    plt.axis('off')
    plt.show()


# The training loop saves one pair of files per epoch, numbered
# 1..num_epochs, so the last epoch is the newest file that exists.
# The previous hard-coded epochs (55 and 107) are never written when
# num_epochs is 15.

# Reconstructed images
reconsPath = os.path.join(sample_dir, 'reconst-{}.png'.format(num_epochs))
_show_image(reconsPath)

# Randomly generated images
genPath = os.path.join(sample_dir, 'sampled-{}.png'.format(num_epochs))
_show_image(genPath)