VAE pytorch

    技术2026-02-27  8

    """Convolutional variational autoencoder (VAE) on MNIST.

    Running this file as a script downloads MNIST into ./data/, trains the
    model for 9 epochs with SGD, and every 50 steps writes an 8x8 grid of the
    current batch's reconstructions to ./image/.
    """
    import os

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim

    batch_size = 64
    latent_vector = 32          # dimensionality of the latent code z
    # The constants below are not used by the code in this file.
    intermediate_vector = 256
    num_class = 10
    det = 1e-10
    lamb = 2.5


    def get_loaders(root="./data/", batch_size=batch_size):
        """Return ``(train_loader, test_loader)`` over MNIST, downloading if needed.

        torchvision is imported here (not at module level) so the model and
        loss can be imported and used without it.
        """
        from torchvision import datasets, transforms

        to_tensor = transforms.ToTensor()
        train_data = datasets.MNIST(root=root, train=True,
                                    transform=to_tensor, download=True)
        # Evaluate on the held-out split, not a second copy of the training set.
        test_data = datasets.MNIST(root=root, train=False,
                                   transform=to_tensor, download=True)
        train_loader = torch.utils.data.DataLoader(
            dataset=train_data, batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            dataset=test_data, batch_size=batch_size, shuffle=False)
        return train_loader, test_loader


    class VAE(nn.Module):
        """Conv encoder -> diagonal Gaussian latent -> transposed-conv decoder.

        Input and output are (N, 1, 28, 28) images; the decoder ends in a
        Sigmoid, so reconstructions lie in [0, 1].
        """

        def __init__(self):
            super(VAE, self).__init__()
            # 1x28x28 -> 32x28x28 -> 32x14x14 -> 64x14x14 -> 64x7x7
            self.encoder = nn.Sequential(
                nn.Conv2d(1, 32, 3, padding=1, stride=1),
                nn.LeakyReLU(negative_slope=0.2),
                nn.Conv2d(32, 32, 3, padding=1, stride=2),
                nn.LeakyReLU(negative_slope=0.2),
                nn.Conv2d(32, 64, 3, padding=1, stride=1),
                nn.LeakyReLU(negative_slope=0.2),
                nn.Conv2d(64, 64, 3, padding=1, stride=2),  # 7*7*64
                nn.LeakyReLU(negative_slope=0.2),
            )
            self.fc_mu = nn.Linear(7 * 7 * 64, latent_vector)
            self.fc_logvar = nn.Linear(7 * 7 * 64, latent_vector)
            self.fc = nn.Linear(latent_vector, 7 * 7 * 64)
            # 64x7x7 -> 64x14x14 -> 32x14x14 -> 32x28x28 -> 1x28x28
            self.decoder = nn.Sequential(
                nn.ConvTranspose2d(64, 64, 4, padding=1, stride=2),
                nn.LeakyReLU(negative_slope=0.2),
                nn.ConvTranspose2d(64, 32, 3, padding=1, stride=1),
                nn.LeakyReLU(negative_slope=0.2),
                nn.ConvTranspose2d(32, 32, 4, padding=1, stride=2),
                nn.LeakyReLU(negative_slope=0.2),
                nn.ConvTranspose2d(32, 1, 3, padding=1, stride=1),
                nn.Sigmoid(),
            )

        def Reparameter(self, mu, logvar):
            """Sample z = mu + eps * exp(logvar / 2) with eps ~ N(0, I).

            ``randn_like`` keeps the noise on mu's device and dtype, so the
            model also works on GPU / in float64.
            """
            eps = torch.randn_like(mu)
            return eps * torch.exp(logvar / 2) + mu

        def forward(self, x):
            """Return ``(reconstruction, mu, logvar)`` for a batch ``x``."""
            # One encoder pass feeds both heads (mu and logvar share features).
            features = self.encoder(x).flatten(1)
            mu = self.fc_mu(features)
            logvar = self.fc_logvar(features)
            z = self.Reparameter(mu, logvar)
            vector = self.fc(z).view(z.size(0), 64, 7, 7)
            return self.decoder(vector), mu, logvar


    # Summed (not averaged) squared error, matching the summed KL term below.
    MSE_Loss = nn.MSELoss(reduction="sum")


    def loss_function(input, output, mu, logvar):
        """Return the batch-summed negative ELBO.

        Args:
            input: the original images.
            output: the decoder's reconstruction, same shape as ``input``.
            mu: posterior means.
            logvar: posterior log-variances, same shape as ``mu``.

        Returns:
            Scalar tensor: sum of squared reconstruction errors plus
            KL(N(mu, exp(logvar)) || N(0, I)).
        """
        # nn.MSELoss takes (prediction, target); the value is symmetric.
        reconstruction = MSE_Loss(output, input)
        kl = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1)
        return reconstruction + kl


    def save_image(output, size, path, Color):
        """Tile a batch of images into a grid and write it to ``path``.

        Args:
            output: array of shape (N, H, W) or (N, H, W, C), values in [0, 1].
            size: ``(rows, cols)`` of the grid; images beyond rows*cols are
                dropped and unused cells stay black.
            path: destination file; the format is chosen from the extension.
            Color: if true write RGB (a single channel is repeated into three),
                otherwise write grayscale.
        """
        from torchvision import transforms

        output = np.asarray(output)
        rows, cols = size[0], size[1]
        h, w = output.shape[1], output.shape[2]
        if Color:
            image = np.zeros((h * rows, w * cols, 3))
        else:
            image = np.zeros((h * rows, w * cols))
        for index, data in enumerate(output[: rows * cols]):
            row, col = divmod(index, cols)
            cell = (slice(h * row, h * row + h), slice(w * col, w * col + w))
            if Color:
                # (h, w) or (h, w, 1) broadcasts across the 3 RGB channels.
                image[cell] = np.reshape(data, (h, w, -1))
            else:
                image[cell] = np.reshape(data, (h, w))
        # scipy.misc.toimage no longer exists; convert to uint8 and let PIL save.
        pixels = np.rint(np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
        transforms.ToPILImage()(pixels).save(path)


    def rescale_image(image):
        """Return ``(image / 1.5 + 0.5) * 255``. Not used by this script."""
        return (image / 1.5 + 0.5) * 255


    model = VAE()
    optimizer = optim.SGD(model.parameters(), lr=0.0001)


    def train(loader=None, num_epochs=9, image_dir="./image"):
        """Train the module-level ``model`` with ``optimizer``.

        Args:
            loader: iterable of ``(images, labels)`` batches; defaults to the
                MNIST training loader from ``get_loaders()``.
            num_epochs: number of passes over ``loader``.
            image_dir: directory for reconstruction grids, written every 50
                steps as ``image_<epoch>_<step>.png`` so epochs don't overwrite
                each other.
        """
        if loader is None:
            loader, _ = get_loaders()
        os.makedirs(image_dir, exist_ok=True)
        model.train()
        for epoch in range(1, num_epochs + 1):
            for i, (data, _) in enumerate(loader):
                output, mu, logvar = model(data)
                optimizer.zero_grad()
                loss = loss_function(data, output, mu, logvar)
                loss.backward()
                optimizer.step()
                if i % 50 == 0:
                    # (N, C, H, W) -> (N, H, W, C) for image writing.
                    np_output = output.detach().cpu().numpy().transpose(0, 2, 3, 1)
                    path = os.path.join(image_dir,
                                        "image_{}_{}.png".format(epoch, i))
                    save_image(np_output, [8, 8], path, True)
                    print("loss={}".format(loss.item()))


    if __name__ == "__main__":
        train()
    Processed: 0.009, SQL: 9