def init_params(model):
    """Initialize the parameters of *model* in place.

    Weight matrices (ndim > 1) get Xavier-normal initialization; vector
    parameters such as biases get small positive uniform values in
    [0.1, 0.2]. Intended to be passed to ``Module.apply``.
    """
    for param in model.parameters():
        if param.dim() <= 1:
            # Bias-like vectors: small positive constants.
            nn.init.uniform_(param, 0.1, 0.2)
        else:
            nn.init.xavier_normal_(param)
def imshow(img):
    """Display up to 16 images from a batch in a 2x8 grid of grayscale panels.

    img: tensor of 28x28 images, flattened or not — it is reshaped to
    (N, 28, 28) and moved to CPU before plotting.
    """
    fig, ax = plt.subplots(2, 8, figsize=(10, 5))
    img = img.view(-1, 28, 28).detach().cpu().numpy()
    for i, npimg in enumerate(img):
        # Only 16 panels exist; previously indices >= 16 wrapped around via
        # i % 8 and silently overwrote earlier panels.
        if i >= 16:
            break
        row, col = divmod(i, 8)
        ax[row][col].imshow(npimg, cmap="gray")
        # axis("off") must be called on the panel itself: plt.axis("off")
        # only affects pyplot's current axes, not the subplot just drawn.
        ax[row][col].axis("off")
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Dimensionality of the latent noise vector z fed to the generator.
input_dim = 32
# Hidden width shared by both networks (classes are defined elsewhere in the file).
hidden_dim = 256
generator = Generator(input_dim,hidden_dim).to(device)
discriminator = Discriminator(hidden_dim).to(device)
# apply() calls init_params on every submodule: Xavier-normal weight
# matrices, uniform(0.1, 0.2) bias vectors.
generator.apply(init_params)
discriminator.apply(init_params)
# Separate optimizers so the two adversaries can be stepped independently.
optimizer_g = optim.Adam(generator.parameters(),lr=0.001)
optimizer_d = optim.Adam(discriminator.parameters(),lr=0.001)
# Binary cross-entropy over the discriminator's real/fake probability.
criterion = nn.BCELoss()
generator.train()
discriminator.train()
# Mean discriminator score on the last training batch of each epoch
# (real images / generated images).
bucket = []
fake_bucket = []
# Per-epoch scores returned by evaluate_model() on the test set
# (real images / generated images).
buffer= []
fake_buffer = []
# Adversarial training: each batch does one discriminator step (real -> 1,
# fake -> 0) followed by one generator step (fake -> 1).
for epoch in range(50):
    avg_loss_g = 0.0
    avg_loss_d = 0.0
    for X_train, label in train_loader:
        X_train = X_train.to(device)
        X_train = X_train.view(-1, 28 * 28)
        # Size noise and label tensors by the actual batch: the final batch
        # from the loader can be smaller than batch_size, which would make
        # BCELoss fail on a shape mismatch.
        n = X_train.size(0)
        z = torch.randn(n, input_dim, device=device)
        X_train_labels = torch.ones(n, 1, device=device)

        # --- Discriminator step ---
        outputs = discriminator(X_train)
        loss_d = criterion(outputs, X_train_labels)
        output = outputs
        X_fake_train = generator(z)
        X_fake_labels = torch.zeros(n, 1, device=device)
        # detach() so the discriminator's backward pass does not also
        # compute (and waste) gradients through the generator.
        outputs = discriminator(X_fake_train.detach())
        loss_d_fake = criterion(outputs, X_fake_labels)
        fake_output = outputs
        loss_d_total = loss_d + loss_d_fake
        optimizer_d.zero_grad()
        loss_d_total.backward()
        optimizer_d.step()

        # --- Generator step: fool the discriminator into scoring fakes as real ---
        z = torch.randn(n, input_dim, device=device)
        X_fake_train = generator(z)
        outputs = discriminator(X_fake_train)
        loss_g = criterion(outputs, X_train_labels)
        optimizer_g.zero_grad()
        loss_g.backward()
        optimizer_g.step()

        # .item() detaches the scalar; accumulating the raw loss tensors
        # would keep every batch's autograd graph alive for the whole epoch.
        avg_loss_d += loss_d_total.item() / len(train_loader)
        avg_loss_g += loss_g.item() / len(train_loader)
    # Record the discriminator's mean score on the last training batch.
    bucket.append(output.mean().item())
    fake_bucket.append(fake_output.mean().item())
    # NOTE(review): evaluate_model is defined further down in this file; this
    # call only works if that definition has already been executed (e.g. in a
    # notebook) — confirm ordering when running as a plain script.
    output, fake_output = evaluate_model()
    # evaluate_model switches both models to eval mode; switch back so the
    # next epoch trains in the correct mode.
    generator.train()
    discriminator.train()
    buffer.append(output)
    fake_buffer.append(fake_output)
    print("epoch : {} 일때 판별자 loss : {} 생성자 loss : {}\n".format(epoch+1,avg_loss_d,avg_loss_g))
    # Visualize generated samples at the final epoch.
    if (epoch + 1) % 50 == 0:
        z = torch.randn(batch_size, input_dim, device=device)
        img = generator(z)
        imshow(img)
def evaluate_model():
    """Score the discriminator on the test set and on freshly generated fakes.

    Returns (output, fake_output): the discriminator's mean score over the
    test dataset for real images and for generated images respectively.
    Temporarily puts both models in eval mode and restores train mode on exit.
    """
    output, fake_output = 0.0, 0.0
    generator.eval()
    discriminator.eval()
    # No gradients are needed for evaluation; wrap the whole pass.
    with torch.no_grad():
        for X_test, label in test_loader:
            X_test = X_test.view(-1, 28 * 28).to(device)
            # Match the real batch size so the fake-score average is over the
            # same number of samples as the real one (the last batch may be
            # smaller than batch_size).
            z = torch.randn(X_test.size(0), input_dim, device=device)
            output += discriminator(X_test).sum().item() / len(test_loader.dataset)
            fake_output += discriminator(generator(z)).sum().item() / len(test_loader.dataset)
    # Restore training mode: previously the models were left in eval mode,
    # which would affect all subsequent training steps.
    generator.train()
    discriminator.train()
    return output, fake_output