import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, datasets
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torch.nn.init
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
# ToTensor only (no Normalize): pixel values become floats in [0, 1], the same
# range as the decoder's Sigmoid output used as the reconstruction.
transform = transforms.Compose([transforms.ToTensor()])
data_root = '~/.pytorch/F_MNIST_data/'
# Build the training split first, then the test split, both from the same root.
trainset, testset = (
  datasets.FashionMNIST(data_root, download=True, train=is_train, transform=transform)
  for is_train in (True, False)
)

image

# Mini-batches of 128; drop_last keeps every training batch the same size.
train_loader = DataLoader(trainset, batch_size=128, shuffle=True, drop_last=True)
test_loader = DataLoader(testset, batch_size=128, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(156)
# Check the device's type string explicitly rather than comparing the
# torch.device object to a str, and use the real API name
# torch.cuda.manual_seed (the misspelled torch.cuda_manula_seed does not exist).
if device.type == "cuda":
  torch.cuda.manual_seed(156)

모델 구현

class Autoencoder(nn.Module):
  """Fully connected autoencoder that squeezes a 784-vector into 3 numbers.

  The encoder goes through widths 784 -> 128 -> 64 -> 12 -> 3, with a ReLU
  between consecutive Linear layers and no activation on the 3-dim code.
  The decoder walks the same widths in reverse and finishes with a Sigmoid,
  so every reconstructed value is squashed into the unit interval.
  """

  def __init__(self):
    super().__init__()

    def mlp(widths, tail=None):
      # Linear layers over consecutive widths, a ReLU before every Linear but
      # the first; optionally append a final activation module.
      layers = []
      for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
        if idx:
          layers.append(nn.ReLU())
        layers.append(nn.Linear(fan_in, fan_out))
      if tail is not None:
        layers.append(tail)
      return nn.Sequential(*layers)

    widths = [28*28, 128, 64, 12, 3]
    # Encoder is created before the decoder so parameter init order (and
    # therefore seeded initial weights) is encoder first, then decoder.
    self.encoder = mlp(widths)
    self.decoder = mlp(widths[::-1], tail=nn.Sigmoid())

  def forward(self, x):
    """Return the pair (code, reconstruction) for input ``x``."""
    code = self.encoder(x)
    reconstruction = self.decoder(code)
    return code, reconstruction

model = Autoencoder().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# Mean squared error between the reconstruction and the original pixels.
criterion = nn.MSELoss()

# Five fixed images (each flattened to 28*28 values) for the reconstruction
# plot. They come from the test set because no training batch ``x`` has been
# produced yet at this point in the script.
view_data = torch.stack([testset[i][0] for i in range(5)]).view(-1, 28*28)

# Train the plain autoencoder for 10 epochs and print the mean batch loss.
for epoch in range(10):
  model.train()
  avg_loss = 0
  for i, (x, target) in enumerate(train_loader):
    # Plain autoencoder: the flattened input is also the reconstruction target.
    x = x.view(-1, 28*28).to(device)
    y = x

    encoded, decoded = model(x)
    loss = criterion(decoded, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # .item() turns the loss into a Python float so the running mean does not
    # chain every batch's autograd graph together (and prints as a number).
    avg_loss += loss.item() / len(train_loader)
  print("epoch : {}일때 loss : {}".format(epoch+1,avg_loss))

시각화

# Reconstruct exactly the images shown in the top row so the rows line up.
# Inference only: eval mode, no autograd, inputs moved to the model's device.
model.eval()
with torch.no_grad():
  _, decoded_data = model(view_data.to(device))
decoded_data = decoded_data.cpu()

fig, ax = plt.subplots(2, 5, figsize=(10, 5))

# Row 0: original images; row 1: autoencoder reconstructions.
for row, images in enumerate((view_data, decoded_data)):
  pixels = images.detach().cpu().numpy()
  for i in range(5):
    img = np.reshape(pixels[i], (28, 28))
    ax[row][i].imshow(img, cmap="gray")
    ax[row][i].set_xticks(())
    ax[row][i].set_yticks(())

image

Denoising Autoencoder

image

def add_noise(img, noise_std=0.3):
  """Return ``img`` corrupted with additive zero-mean Gaussian noise.

  Args:
    img: Tensor of any shape (here: image batches or flattened images).
    noise_std: Standard deviation of the noise; defaults to 0.3, the level
      used throughout this notebook.

  Returns:
    ``img + noise`` with the same shape as ``img``. The noise is drawn with
    ``torch.randn_like`` so it matches ``img``'s device and dtype (a GPU
    input no longer gets CPU noise). Values are not clipped.
  """
  noise = torch.randn_like(img) * noise_std
  return img + noise

# Fine-tune the same model as a denoiser: noisy input, clean target.
for epoch in range(10):
  model.train()
  avg_loss = 0
  for i, (x, target) in enumerate(train_loader):
    # Noise is added on the CPU batch, then both tensors move to the device.
    noisy_x = add_noise(x)
    noisy_x = noisy_x.view(-1, 28*28).to(device)
    y = x.view(-1, 28*28).to(device)

    encoded, decoded = model(noisy_x)
    loss = criterion(decoded, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # .item() keeps the running mean a plain float instead of a tensor that
    # links every batch's autograd graph.
    avg_loss += loss.item() / len(train_loader)
  print("epoch : {}일때 loss : {}".format(epoch+1,avg_loss))

# First five images of the last training batch, flattened, kept on the CPU.
sample_data = x[:5].view(-1, 28*28).cpu()
test_x = sample_data

# Inference only: eval mode and no autograd; inputs go to the model's device
# and outputs come back to the CPU so they can be converted to NumPy later.
model.eval()
with torch.no_grad():
  _, decoded_data = model(test_x.to(device))
  decoded_data = decoded_data.cpu()

  original_img = sample_data
  # Noise is generated on the CPU copy, then the noisy batch is moved over.
  noisy_img = add_noise(original_img)
  _, recovered_img = model(noisy_img.to(device))
  recovered_img = recovered_img.cpu()

시각화

fig, ax = plt.subplots(3, 5, figsize=(10, 5))

# Row 0: clean input; row 1: noisy input; row 2: denoised reconstruction.
for row, images in enumerate((original_img, noisy_img, recovered_img)):
  # detach().cpu() makes .numpy() safe whether or not the tensor carries
  # autograd history or lives on the GPU.
  pixels = images.detach().cpu().numpy()
  for i in range(5):
    img = np.reshape(pixels[i], (28, 28))
    ax[row][i].imshow(img, cmap="gray")
    ax[row][i].set_xticks(())
    ax[row][i].set_yticks(())

image

# Denoise a single image: the first image of the last training batch.
sample_data = x[0].view(-1, 28*28).cpu()
original_x = sample_data[0]
# Noise is added on the CPU, then the noisy vector is sent to the model's device.
noisy_x = add_noise(original_x)
model.eval()
with torch.no_grad():
  _, recovered_x = model(noisy_x.to(device))

fig, ax = plt.subplots(1, 3, figsize=(15, 15))

original_img = original_x.view(28, 28).numpy()
noisy_img = noisy_x.view(28, 28).numpy()
recovered_img = recovered_x.view(28, 28).cpu().numpy()
panels = (
  ("Original Img", original_img),
  ("Noisy Img", noisy_img),
  ("Recovered Img", recovered_img),
)
for axis, (title, img) in zip(ax, panels):
  axis.set_title(title)
  axis.imshow(img, cmap="gray")

image

Categories:

Updated: