[딥러닝-오토엔코더] 오토엔코더를 이용한 노이즈 제거

1
post-thumbnail

06번째 게시물

안녕하세요. 이번 게시물에서는 MNIST 손글씨 데이터셋을 가지고 랜덤한 값으로 노이즈를 인위적으로 섞은 뒤 오토엔코더를 통해 가우스 노이즈를 제거해보려고 합니다.

1. 노이즈 제거 개념

이미지에서 노이즈는 본래의 값이 아니지만 섞이게 되는 값을 의미하며 주로 초점 문제나 통신 과정에서 생기는 가우스 노이즈가 존재합니다.
오토엔코더는 비지도학습의 일종으로 input data가 라벨의 역할을 하게 됩니다. 따라서 input data를 라벨로 입력하고 input data에 가우시안 noise를 섞은 수정된 input을 실제 모델의 input에 입력시켜 줍니다.
그 후 가우스 노이즈가 섞은 이미지를 학습된 모델에 입력시켜 노이즈를 제거한 상태로 복원해보려고 합니다.

2. 가우시안 노이즈 포함 이미지 생성

1) 가우시안 노이즈가 포함된 이미지 로드 함수

def load_image():
    training_data = MNIST(root="./",train=False,download=True,transform=ToTensor())
    labels = MNIST(root="./",train=False,download=True,transform=ToTensor())
    images = []

    for image in training_data:
        noisy_input = gaussian_noise(image[0][0].clone().detach())
        input_tensor = noisy_input.clone().detach()
        images.append(torch.unsqueeze(input_tensor,dim=0))

    return images, labels

def gaussian_noise(x, scale=0.2):
    gaussian_data_x = x+np.random.normal(loc=0,scale=scale,size=x.shape)
    gaussian_data_x = np.clip(gaussian_data_x, 0, 1)
    gaussian_data_x = gaussian_data_x.type(torch.FloatTensor)
    
    return gaussian_data_x

2) 실제 이미지를 load하고 1)에서 만든 함수를 통해 가우시안 노이즈 포함 dataset 생성

if __name__ == "__main__":

##DATA 준비
    images, labels = load_image()
    train_images, test_images, train_labels, test_labels = train_test_split(images,labels,test_size=0.2,random_state=777)

    train_dataset = dataset(images=train_images, labels=train_labels)
    test_dataset = dataset(images=test_images, labels=test_labels)

위와 같이 input data에 가우시안 노이즈를 섞은 데이터셋을 완성하고 input으로 가우시안 노이즈 포함 데이터셋, label로 일반 데이터셋을 오토엔코더에 입력하면 나머지는 04번째 게시물과 동일하게 진행 됩니다.

3. 오토엔코더 테스트 결과



4. 결론

오토 엔코더 모델을 좀 단순하게 만들어서 완벽하게 노이즈가 제거되지는 않았고 압축된 피처맵을 복원할 때에 어느 정도 원본이미지에서 변형이 일어나는 것을 확인할 수 있었습니다.

5. 전체 코드

import tqdm
import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.optim.adam import Adam
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torch.utils.data.dataset import Dataset
from torchvision.datasets.mnist import MNIST
from sklearn.model_selection import train_test_split

def load_image():
    training_data = MNIST(root="./",train=False,download=True,transform=ToTensor())
    labels = MNIST(root="./",train=False,download=True,transform=ToTensor())
    images = []

    for image in training_data:
        noisy_input = gaussian_noise(image[0][0].clone().detach())
        input_tensor = noisy_input.clone().detach()
        images.append(torch.unsqueeze(input_tensor,dim=0))

    return images, labels

def gaussian_noise(x, scale=0.2):
    gaussian_data_x = x+np.random.normal(loc=0,scale=scale,size=x.shape)
    gaussian_data_x = np.clip(gaussian_data_x, 0, 1)
    gaussian_data_x = gaussian_data_x.type(torch.FloatTensor)
    
    return gaussian_data_x

class dataset(Dataset):
    def __init__(self,images,labels):
        super(dataset,self).__init__()
        self.images = images
        self.labels = labels

    def __len__(self):

        return len(self.labels)

    def __getitem__(self,index):
        image = self.images[index]
        label = self.labels[index][0] / 255
        label = label.type(torch.FloatTensor)

        return image, label

class BasicBlock(nn.Module):

    def __init__(self,in_channels,out_channels,hidden_dim):
        super(BasicBlock,self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels,out_channels=hidden_dim,kernel_size=3,padding=1)
        self.conv2 = nn.Conv2d(in_channels=hidden_dim,out_channels=out_channels,kernel_size=3,padding=1)
        self.relu = nn.ReLU()

    def forward(self,x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)

        return x

class Encoder(nn.Module):

    def __init__(self):
        super(Encoder,self).__init__()
        self.conv1 = BasicBlock(in_channels=1, out_channels=16, hidden_dim=16)
        self.conv2 = BasicBlock(in_channels=16, out_channels=8, hidden_dim=8)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self,x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.pool(x)

        return x

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder,self).__init__()
        self.conv1 = BasicBlock(in_channels=8,out_channels=8, hidden_dim=8)
        self.conv2 = BasicBlock(in_channels=8, out_channels=16, hidden_dim=16)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1)
        self.upsample1 = nn.ConvTranspose2d(8,8,kernel_size=2,stride=2)
        self.upsample2 = nn.ConvTranspose2d(16,16,kernel_size=2,stride=2)

    def forward(self,x):
        x = self.conv1(x)
        x = self.upsample1(x)
        x = self.conv2(x)
        x = self.upsample2(x)
        x = self.conv3(x)

        return x

class CAE(nn.Module):
    def __init__(self):
        super(CAE,self).__init__()
        self.enc = Encoder()
        self.dec = Decoder()

    def forward(self,x):
        x = self.enc(x)
        x = self.dec(x)

        return x

if __name__ == "__main__":

##DATA 준비
    images, labels = load_image()
    train_images, test_images, train_labels, test_labels = train_test_split(images,labels,test_size=0.2,random_state=777)

    train_dataset = dataset(images=train_images, labels=train_labels)
    test_dataset = dataset(images=test_images, labels=test_labels)
    train_dataloader = DataLoader(train_dataset,shuffle=True,batch_size=32)
    test_dataloader = DataLoader(test_dataset,shuffle=True,batch_size=1)

    ##DEVICE 설정
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    ##모델 정의

    CAE_model = CAE().to(device=device)

    ##하이퍼 파라미터 설정

    lr = 0.001

    epoch = 30

    optim = Adam(params=CAE_model.parameters(),lr=lr)

    criterion = nn.MSELoss()

    save_path = "C:/Users/PC_1M/Desktop/코딩/딥러닝 알고리즘/CAE_노이즈제거/model.pt"

    signal = input(str("train : y test : n --> "))

    if signal == "y":

        ##학습
        for i in range(epoch):
            epoch_loss = 0
            iterator = tqdm.tqdm(train_dataloader)

            for image, label in iterator:
                optim.zero_grad()
                pred = CAE_model(image.to(device=device))
                loss = criterion(pred,label.to(device=device))
                loss.backward()
                optim.step()
                batch_loss = loss.item()
                epoch_loss += batch_loss
                avg_epoch_loss = epoch_loss / len(train_dataloader)

            iterator.set_description(f"epoch{i+1}, loss:{avg_epoch_loss}")
            print(iterator)

        torch.save(CAE_model.state_dict(),save_path)

    elif signal == "n":
        with torch.no_grad():
            iterator = tqdm.tqdm(test_dataloader)
            CAE_model.load_state_dict(torch.load(save_path,map_location=device))
            image, label = next(iter(test_dataloader))
            pred = CAE_model(image.to(device=device))
            
            label = torch.squeeze(label)
            plt.subplot(1,3,1)
            plt.imshow(label)

            noise_image = torch.squeeze(image)
            plt.subplot(1,3,2)
            plt.imshow(noise_image)
            
            denoise_image = torch.squeeze(pred).detach().cpu()
            plt.subplot(1,3,3)
            plt.imshow(denoise_image)
            plt.show()
profile
재미로 해보는 다양한 AI프로젝트 모음집

0개의 댓글