안녕하세요. 이번 게시물에서는 MNIST 손글씨 데이터셋을 가지고 랜덤한 값으로 노이즈를 인위적으로 섞은 뒤 오토엔코더를 통해 가우스 노이즈를 제거해보려고 합니다.
이미지에서 노이즈는 본래의 값이 아니지만 섞이게 되는 값을 의미하며 주로 초점 문제나 통신 과정에서 생기는 가우스 노이즈가 존재합니다.
오토엔코더는 비지도학습의 일종으로 input data가 라벨의 역할을 하게 됩니다. 따라서 input data를 라벨로 입력하고 input data에 가우시안 noise를 섞은 수정된 input을 실제 모델의 input에 입력시켜 줍니다.
그 후 가우스 노이즈가 섞은 이미지를 학습된 모델에 입력시켜 노이즈를 제거한 상태로 복원해보려고 합니다.
1) 가우시안 노이즈가 포함된 이미지 로드 함수
def load_image():
training_data = MNIST(root="./",train=False,download=True,transform=ToTensor())
labels = MNIST(root="./",train=False,download=True,transform=ToTensor())
images = []
for image in training_data:
noisy_input = gaussian_noise(image[0][0].clone().detach())
input_tensor = noisy_input.clone().detach()
images.append(torch.unsqueeze(input_tensor,dim=0))
return images, labels
def gaussian_noise(x, scale=0.2):
gaussian_data_x = x+np.random.normal(loc=0,scale=scale,size=x.shape)
gaussian_data_x = np.clip(gaussian_data_x, 0, 1)
gaussian_data_x = gaussian_data_x.type(torch.FloatTensor)
return gaussian_data_x
2) 실제 이미지를 load하고 1)에서 만든 함수를 통해 가우시안 노이즈 포함 dataset 생성
if __name__ == "__main__":
##DATA 준비
images, labels = load_image()
train_images, test_images, train_labels, test_labels = train_test_split(images,labels,test_size=0.2,random_state=777)
train_dataset = dataset(images=train_images, labels=train_labels)
test_dataset = dataset(images=test_images, labels=test_labels)
위와 같이 input data에 가우시안 노이즈를 섞은 데이터셋을 완성하고 input으로 가우시안 노이즈 포함 데이터셋, label로 일반 데이터셋을 오토엔코더에 입력하면 나머지는 04번째 게시물과 동일하게 진행 됩니다.
오토 엔코더 모델을 좀 단순하게 만들어서 완벽하게 노이즈가 제거되지는 않았고 압축된 피처맵을 복원할 때에 어느 정도 원본이미지에서 변형이 일어나는 것을 확인할 수 있었습니다.
import tqdm
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.optim.adam import Adam
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch.utils.data.dataset import Dataset
from torchvision.datasets.mnist import MNIST
from sklearn.model_selection import train_test_split
def load_image():
training_data = MNIST(root="./",train=False,download=True,transform=ToTensor())
labels = MNIST(root="./",train=False,download=True,transform=ToTensor())
images = []
for image in training_data:
noisy_input = gaussian_noise(image[0][0].clone().detach())
input_tensor = noisy_input.clone().detach()
images.append(torch.unsqueeze(input_tensor,dim=0))
return images, labels
def gaussian_noise(x, scale=0.2):
gaussian_data_x = x+np.random.normal(loc=0,scale=scale,size=x.shape)
gaussian_data_x = np.clip(gaussian_data_x, 0, 1)
gaussian_data_x = gaussian_data_x.type(torch.FloatTensor)
return gaussian_data_x
class dataset(Dataset):
def __init__(self,images,labels):
super(dataset,self).__init__()
self.images = images
self.labels = labels
def __len__(self):
return len(self.labels)
def __getitem__(self,index):
image = self.images[index]
label = self.labels[index][0] / 255
label = label.type(torch.FloatTensor)
return image, label
class BasicBlock(nn.Module):
def __init__(self,in_channels,out_channels,hidden_dim):
super(BasicBlock,self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_channels,out_channels=hidden_dim,kernel_size=3,padding=1)
self.conv2 = nn.Conv2d(in_channels=hidden_dim,out_channels=out_channels,kernel_size=3,padding=1)
self.relu = nn.ReLU()
def forward(self,x):
x = self.conv1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.relu(x)
return x
class Encoder(nn.Module):
def __init__(self):
super(Encoder,self).__init__()
self.conv1 = BasicBlock(in_channels=1, out_channels=16, hidden_dim=16)
self.conv2 = BasicBlock(in_channels=16, out_channels=8, hidden_dim=8)
self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
def forward(self,x):
x = self.conv1(x)
x = self.pool(x)
x = self.conv2(x)
x = self.pool(x)
return x
class Decoder(nn.Module):
def __init__(self):
super(Decoder,self).__init__()
self.conv1 = BasicBlock(in_channels=8,out_channels=8, hidden_dim=8)
self.conv2 = BasicBlock(in_channels=8, out_channels=16, hidden_dim=16)
self.conv3 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1)
self.upsample1 = nn.ConvTranspose2d(8,8,kernel_size=2,stride=2)
self.upsample2 = nn.ConvTranspose2d(16,16,kernel_size=2,stride=2)
def forward(self,x):
x = self.conv1(x)
x = self.upsample1(x)
x = self.conv2(x)
x = self.upsample2(x)
x = self.conv3(x)
return x
class CAE(nn.Module):
def __init__(self):
super(CAE,self).__init__()
self.enc = Encoder()
self.dec = Decoder()
def forward(self,x):
x = self.enc(x)
x = self.dec(x)
return x
if __name__ == "__main__":
##DATA 준비
images, labels = load_image()
train_images, test_images, train_labels, test_labels = train_test_split(images,labels,test_size=0.2,random_state=777)
train_dataset = dataset(images=train_images, labels=train_labels)
test_dataset = dataset(images=test_images, labels=test_labels)
train_dataloader = DataLoader(train_dataset,shuffle=True,batch_size=32)
test_dataloader = DataLoader(test_dataset,shuffle=True,batch_size=1)
##DEVICE 설정
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
##모델 정의
CAE_model = CAE().to(device=device)
##하이퍼 파라미터 설정
lr = 0.001
epoch = 30
optim = Adam(params=CAE_model.parameters(),lr=lr)
criterion = nn.MSELoss()
save_path = "C:/Users/PC_1M/Desktop/코딩/딥러닝 알고리즘/CAE_노이즈제거/model.pt"
signal = input(str("train : y test : n --> "))
if signal == "y":
##학습
for i in range(epoch):
epoch_loss = 0
iterator = tqdm.tqdm(train_dataloader)
for image, label in iterator:
optim.zero_grad()
pred = CAE_model(image.to(device=device))
loss = criterion(pred,label.to(device=device))
loss.backward()
optim.step()
batch_loss = loss.item()
epoch_loss += batch_loss
avg_epoch_loss = epoch_loss / len(train_dataloader)
iterator.set_description(f"epoch{i+1}, loss:{avg_epoch_loss}")
print(iterator)
torch.save(CAE_model.state_dict(),save_path)
elif signal == "n":
with torch.no_grad():
iterator = tqdm.tqdm(test_dataloader)
CAE_model.load_state_dict(torch.load(save_path,map_location=device))
image, label = next(iter(test_dataloader))
pred = CAE_model(image.to(device=device))
label = torch.squeeze(label)
plt.subplot(1,3,1)
plt.imshow(label)
noise_image = torch.squeeze(image)
plt.subplot(1,3,2)
plt.imshow(noise_image)
denoise_image = torch.squeeze(pred).detach().cpu()
plt.subplot(1,3,3)
plt.imshow(denoise_image)
plt.show()