AlexNet으로 이미지 분류하기

choonsikmom · June 8, 2022

Continuing from the previous post (LeNet-5), this time I'll run image classification with an AlexNet model. The data and overall procedure are identical to before; think of it as the same pipeline with only the model swapped out.

👀 AlexNet

AlexNet consists of five convolutional layers and three fully connected layers,
and the final fully connected layer uses a softmax activation to classify the 1000 categories.

| Layer type | Feature maps | Size | Kernel size | Stride | Activation |
|---|---|---|---|---|---|
| Image | 1 | 227×227 | - | - | - |
| Convolution | 96 | 55×55 | 11×11 | 4 | ReLU |
| Max pooling | 96 | 27×27 | 3×3 | 2 | - |
| Convolution | 256 | 27×27 | 5×5 | 1 | ReLU |
| Max pooling | 256 | 13×13 | 3×3 | 2 | - |
| Convolution | 384 | 13×13 | 3×3 | 1 | ReLU |
| Convolution | 384 | 13×13 | 3×3 | 1 | ReLU |
| Convolution | 256 | 13×13 | 3×3 | 1 | ReLU |
| Max pooling | 256 | 6×6 | 3×3 | 2 | - |
| Fully connected | - | 4096 | - | - | ReLU |
| Fully connected | - | 4096 | - | - | ReLU |
| Fully connected | - | 1000 | - | - | Softmax |
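
The spatial sizes in this table follow the usual convolution/pooling output-size formula. Here is a minimal sketch that reproduces the 227 → 55 → 27 → 13 → 6 progression; the padding values are assumptions based on the common AlexNet configuration, since the table does not list them:

# out = (in - kernel + 2*padding) // stride + 1
def out_size(in_size, kernel, stride, padding=0):
    return (in_size - kernel + 2 * padding) // stride + 1

s = out_size(227, 11, 4)          # conv1, no padding          -> 55
s = out_size(s, 3, 2)             # max pool 3x3, stride 2     -> 27
s = out_size(s, 5, 1, padding=2)  # conv2, padding 2           -> 27
s = out_size(s, 3, 2)             # max pool                   -> 13
s = out_size(s, 3, 1, padding=1)  # conv3 (conv4, conv5 keep 13 the same way)
s = out_size(s, 3, 2)             # final max pool             -> 6
print(s)                          # 6, matching the 6x6 feature maps before the FC layers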

The network has about 66 million trainable parameters.
Its input is a 227×227×3 RGB image, and it outputs a 1000×1 vector of probabilities, one per class (category).
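
As a quick sanity check of these shapes, torchvision's reference AlexNet (untrained) can be run on a dummy 227×227 input. This is just a sketch; alexnet(weights=None) requires torchvision 0.13 or newer, while older versions use pretrained=False instead:

import torch
from torchvision.models import alexnet

ref = alexnet(weights=None)           # untrained reference AlexNet with the original 1000-class head
x = torch.randn(1, 3, 227, 227)       # one dummy RGB image, 227x227x3
probs = torch.softmax(ref(x), dim=1)  # softmax over the 1000 logits
print(probs.shape)                    # torch.Size([1, 1000])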

Now let's use AlexNet to classify dog and cat images.


Image classification with AlexNet

# import libraries
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
import os
import cv2
import random
import time
from PIL import Image
from tqdm.auto import tqdm  # tqdm_notebook is deprecated; tqdm.auto picks the right frontend
import matplotlib.pyplot as plt

# use the GPU if one is available
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# data preprocessing: random augmentation for training, deterministic resize/crop for validation
class ImageTransform() :
    def __init__(self, resize, mean, std) :
        self.data_transform = {
            'train' : transforms.Compose([
                    transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)),  # random crop-and-resize augmentation
                    transforms.RandomHorizontalFlip(),                       # random left-right flip
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std)
            ]),
            'val' : transforms.Compose([
                    transforms.Resize(256),
                    transforms.CenterCrop(resize),
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std)
            ])
        }
        
    def __call__(self, img, phase) :
        return self.data_transform[phase](img)
# data load & split
cat_dir = '../080289-main/chap06/data/dogs-vs-cats/Cat/'
dog_dir = '../080289-main/chap06/data/dogs-vs-cats/Dog/'

cat_images_filepath = sorted([os.path.join(cat_dir, f) for f in os.listdir(cat_dir)])
dog_images_filepath = sorted([os.path.join(dog_dir, f) for f in os.listdir(dog_dir)])

images_filepath = [*cat_images_filepath, *dog_images_filepath]
correct_images_filepath = [i for i in images_filepath if cv2.imread(i) is not None]

random.seed(3)
random.shuffle(correct_images_filepath)
train_image_filepaths = correct_images_filepath[:400]
val_image_filepaths = correct_images_filepath[400:-10]
test_image_filepaths = correct_images_filepath[-10:]

print(len(train_image_filepaths), len(val_image_filepaths), len(test_image_filepaths))
# define custom dataset
class DogvsCatDataset(Dataset) :
    def __init__(self, file_list, transform=None, phase='train') :
        self.file_list = file_list
        self.transform = transform
        self.phase = phase
        
    def __len__(self) :
        return len(self.file_list)
    
    def __getitem__(self, idx) :
        img_path = self.file_list[idx]
        img = Image.open(img_path).convert('RGB')  # force 3 channels in case of grayscale/RGBA files
        img_transformed = self.transform(img, self.phase)
        
        # filenames are expected to look like cat.123.jpg / dog.456.jpg, so the prefix encodes the label
        label = img_path.split('/')[-1].split('.')[0]
        if label == 'dog' :
            label = 1
        elif label == 'cat' :
            label = 0
            
        return img_transformed, label
# define variables

size = 256
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
batch_size = 32
train_dataset = DogvsCatDataset(train_image_filepaths, transform=ImageTransform(size, mean, std),
                               phase='train')
val_dataset = DogvsCatDataset(val_image_filepaths, transform=ImageTransform(size, mean, std),
                             phase='val')
test_dataset = DogvsCatDataset(test_image_filepaths, transform=ImageTransform(size, mean, std),
                            phase='val')

index = 0
print(train_dataset.__getitem__(index)[0].size())
print(train_dataset.__getitem__(index)[1])
# wrap the datasets in DataLoaders for batched iteration
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
dataloader_dict = {'train' : train_dataloader, 'val' : val_dataloader}

batch_iterator = iter(train_dataloader)
inputs, label = next(batch_iterator)
print(inputs.size())
print(label)
# define model network
class AlexNet(nn.Module) :
    def __init__(self) -> None :
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
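        # classifier head adapted for this 2-class task (cat vs. dog):
        # 4096 -> 512 -> 2 instead of the original 4096 -> 4096 -> 1000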
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256*6*6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 2),)
        
    def forward(self, x :torch.Tensor) -> torch.Tensor :
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
# create model
model = AlexNet()
model.to(device)

# optimizer, loss function
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
# check the model network
from torchsummary import summary
summary(model, input_size=(3, 256, 256))

# define training function
def train_model(model, dataloader_dict, criterion, optimizer, n_epochs) :
    since = time.time()
    best_acc = 0.0
    
    for epoch in range(n_epochs) :
        print(f'Epoch {epoch+1}/{n_epochs}')
        print('-' * 20)
        
        for phase in ['train', 'val'] :
            if phase == 'train' :
                model.train()
                
            else :
                model.eval()
                
            epoch_loss = 0.0
            epoch_corrects = 0
            
            for inputs, labels in tqdm(dataloader_dict[phase]) :
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train') :
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train' :
                        loss.backward()
                        optimizer.step()
                        
                epoch_loss += loss.item() * inputs.size(0)
                epoch_corrects += torch.sum(preds == labels.data)
                
            epoch_loss = epoch_loss / len(dataloader_dict[phase].dataset)
            epoch_acc = epoch_corrects.double() / len(dataloader_dict[phase].dataset)
            
            print(f'{phase} Loss : {epoch_loss:.4f} Acc : {epoch_acc:.4f}')
        
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')

    return model
n_epochs = 10
model = train_model(model, dataloader_dict, criterion, optimizer, n_epochs)
# make prediction
import pandas as pd
id_list = []
pred_list = []
_id = 0
with torch.no_grad() :
    for test_path in tqdm(test_image_filepaths) :
        img = Image.open(test_path)
        _id = test_path.split('/')[-1].split('.')[1]
        transform = ImageTransform(size, mean, std)
        img = transform(img, phase='val')
        img = img.unsqueeze(0)
        img = img.to(device)
        
        model.eval()
        outputs = model(img)
        preds = F.softmax(outputs, dim=1)[:, 1].tolist()  # probability of class 1 (dog)
        
        id_list.append(_id)
        pred_list.append(preds[0])
        
res = pd.DataFrame({'id' : id_list,
                   'label' : pred_list})
res.to_csv('./alexnet.csv', index=False)
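
The saved label column holds the predicted probability of the dog class rather than a hard 0/1 prediction. If a class label is needed, it can be recovered by thresholding at 0.5, which is the same rule the visualization below uses (a small sketch):

# turn the dog probability into a hard 0/1 class (0 = cat, 1 = dog)
res['pred_class'] = (res['label'] > 0.5).astype(int)
print(res.head())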
# result visualization
class_ = {0 : 'cat', 1 : 'dog'}

def display_image_grid(images_filepaths, pred_labels=(), cols=5) :
    rows = len(images_filepaths) // cols
    figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6))
    for i, image_filepath in enumerate(images_filepaths) :
        image = cv2.imread(image_filepath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        # look up the predicted dog probability for this image by its file id
        img_id = image_filepath.split('/')[-1].split('.')[1]
        label = res.loc[res['id'] == img_id, 'label'].values[0]
        
        if label > 0.5 :
            label = 1
        else :
            label = 0
            
        ax.ravel()[i].imshow(image)
        ax.ravel()[i].set_title(class_[label])
        ax.ravel()[i].set_axis_off()
        
    plt.tight_layout()
    plt.show()
display_image_grid(test_image_filepaths)


📚 reference

  • 딥러닝 파이토치 교과서, by 서지영 (Gilbut)
  • github