์ง๋ ์๊ฐ์ ์ด์ด, Dataset / DataLoader class๋ฅผ ํ์ฉํ์ฌ ๊ฐ์์ง ๋ถ๋ฅ ๋ชจ๋ธ์ ์์ฑํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค! ๐
✨ 왜 강아지냐구요? 🐶 강아지는 귀여우니까요 :)
👉 그럼, LET'S DIGGIN' !
โ Data๋ Kaggle์์ Stanford Dog Dataset์ ์ค๋นํ์ต๋๋ค.
✔ https://www.kaggle.com/datasets/jessicali9530/stanford-dogs-dataset
โ train : val = 0.85 : 0.15 split ์ํํ์์ต๋๋ค.
import os
import shutil
# Split the Stanford Dogs images into train / val folders and build the
# path / label lists that the datasets below consume.

# Project root: two directory levels above this file.
root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Source images: one sub-folder per class (dog breed).
org_image_path = os.path.join(root_path, "archive/images/Images/")

# Fraction of each class's images that goes to the training split.
train_ratio = 0.85

# Class folder names, sorted so that the label index assigned to each class
# is identical on every run (os.listdir returns entries in arbitrary order).
# Non-directory entries are skipped so they are not mistaken for classes.
labels = sorted(
    name for name in os.listdir(org_image_path)
    if os.path.isdir(os.path.join(org_image_path, name)))
# Independent copy of the class names: position i is the name of label id i.
class_names = list(labels)

train_root = os.path.join(root_path, 'data/train/')
val_root = os.path.join(root_path, 'data/val/')
os.makedirs(train_root, exist_ok=True)
os.makedirs(val_root, exist_ok=True)

labels_cnt_list = []  # number of source images per class, same order as labels
train_img_list = []
train_label_list = []
val_img_list = []
val_label_list = []

for idx, breed in enumerate(labels):
    src_dir = os.path.join(org_image_path, breed)
    # Sorted so the train/val split is reproducible.
    fnames = sorted(os.listdir(src_dir))
    labels_cnt_list.append(len(fnames))
    # The first num_train files go to train and the remainder to val, so
    # train holds exactly int(count * train_ratio) images.
    num_train = int(len(fnames) * train_ratio)
    splits = (
        (train_root, fnames[:num_train], train_img_list, train_label_list),
        (val_root, fnames[num_train:], val_img_list, val_label_list),
    )
    for split_root, split_fnames, dst_imgs, dst_labels in splits:
        dst_dir = os.path.join(split_root, breed)
        os.makedirs(dst_dir, exist_ok=True)
        for fname in split_fnames:
            dst_path = os.path.join(dst_dir, fname)
            shutil.copy(os.path.join(src_dir, fname), dst_path)
            dst_imgs.append(dst_path)
            dst_labels.append(idx)
✔ ViT : Vision Transformer!
โ Computer Vision Task์์ ํญ์ ๋น ์ง์ง ์๋ CNN ์ํคํ
์ฒ๋ฅผ ์ ์ธํ๊ณ , ์ค์ง Self-attention๋ง์ ์ฌ์ฉํ์ฌ์๋ ์ถฉ๋ถํ CV Task๋ฅผ ์ํํ ์ ์์์ ๋ณด์ฌ์ค ๋
ผ๋ฌธ์
๋๋ค :)
โ pytorch์ ๊ตฌํ์ฒด๊ฐ ์ค๋น๋์ด ์์ด, ๊ทธ๋๋ก ํ์ฉํ๋ฉด ๋ฉ๋๋ค.
✨ timm (PyTorch Image Models) 패키지를 사용할 예정입니다!
๐ timm ํจํค์ง๋ ๋์ ์ฑ๋ฅ์ ๋ณด์ด๋ Computer Vision D/L ์๊ณ ๋ฆฌ์ฆ๋ค์ ๋ฏธ๋ฆฌ pytorch๋ก ๊ตฌํํด ๋์ ํจํค์ง๋ก์, ๊ฐ๋จํ๊ณ ๋น ๋ฅด๊ฒ ๋ชจ๋ธ์ ๊ตฌํํ ์ ์๊ฒ ๋์์ค๋๋ค :)
pip install timm
์ง๋ ์๊ฐ์ ๊ตฌํํ์๋ Dataset๊ณผ DataLoader class๋ฅผ ํ์ฉํ๊ฒ ์ต๋๋ค. ๐
import torch
from PIL import Image
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset that loads images from file paths and applies a
    phase-dependent transform.

    Attributes
    ----------
    img_list : list of str
        Paths of the image files.
    label_list : list of int
        Class index for each entry of ``img_list`` (same order and length).
    phase : 'train' or 'val'
        Selects training or validation preprocessing; passed to ``transform``.
    transform : callable
        Called as ``transform(img, phase)``, e.g. an instance of MyTransform.
    loader : callable, optional
        Called as ``loader(path)`` to load one image. Defaults to opening the
        file with PIL and converting it to RGB.

    Raises
    ------
    ValueError
        If ``img_list`` and ``label_list`` have different lengths.
    """

    def __init__(self, img_list, label_list, phase, transform, loader=None):
        if len(img_list) != len(label_list):
            raise ValueError(
                "img_list and label_list must have the same length "
                "({} != {})".format(len(img_list), len(label_list)))
        self.img_list = img_list
        self.label_list = label_list
        self.phase = phase  # 'train' or 'val'
        self.transform = transform  # image preprocessing
        self.loader = self._default_loader if loader is None else loader

    @staticmethod
    def _default_loader(path):
        """Open the image at *path* and return an RGB copy, closing the file."""
        # convert() returns a new, fully loaded image, so the file handle can
        # be released immediately instead of waiting for garbage collection.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __len__(self):
        """Return the number of images."""
        return len(self.img_list)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the item at *index*."""
        # img_list holds file paths, so the image must be loaded before it
        # can be handed to the transform.
        img = self.loader(self.img_list[index])
        transformed_img = self.transform(img, self.phase)
        label = self.label_list[index]
        return transformed_img, label
from torchvision import models, transforms
class MyTransform():
    """Phase-specific image preprocessing pipelines.

    Attributes
    ----------
    resize : int
        Side length (pixels) of the square image produced by the transform.
    mean : (R, G, B)
        Per-channel mean passed to ``transforms.Normalize``.
    std : (R, G, B)
        Per-channel standard deviation passed to ``transforms.Normalize``.
    """

    def __init__(self, resize, mean, std):
        target_size = (resize, resize)
        train_steps = [
            # Random crop covering 50%-100% of the area, resized to target_size.
            transforms.RandomResizedCrop(target_size, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # convert to a tensor
            transforms.Normalize(mean, std),  # standardize each channel
        ]
        val_steps = [
            # Deterministic resize only; no augmentation at validation time.
            transforms.Resize(target_size),
            transforms.ToTensor(),  # convert to a tensor
            transforms.Normalize(mean, std),  # standardize each channel
        ]
        self.data_transform = dict(
            train=transforms.Compose(train_steps),
            val=transforms.Compose(val_steps),
        )

    def __call__(self, img, phase='train'):
        """Apply the pipeline selected by *phase* to *img*.

        Parameters
        ----------
        phase : 'train' or 'val'
            Which preprocessing pipeline to use.
        """
        pipeline = self.data_transform[phase]
        return pipeline(img)
👉 좋습니다! 이제 사전 선언해야 할 클래스는 모두 선언하였습니다.
👉 그러면, 학습을 수행하는 코드를 작성해보도록 하겠습니다.
# Input resolution (matches the "_224" in the model name used below) and the
# normalization statistics (the standard ImageNet values).
size = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

train_dataset = MyDataset(img_list=train_img_list, label_list=train_label_list,
                          phase="train", transform=MyTransform(size, mean, std))
val_dataset = MyDataset(img_list=val_img_list, label_list=val_label_list,
                        phase="val", transform=MyTransform(size, mean, std))

image_datasets = {'train': train_dataset, 'val': val_dataset}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

batch_size = 32
# Shuffle only the training data; the validation order does not matter.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False)

# Collect the loaders in a dict keyed by phase.
dataloaders_dict = {"train": train_dataloader, "val": val_dataloader}
# Same dict under the name the training / visualization code refers to.
dataloaders = dataloaders_dict
import timm
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

# Stanford Dogs has 120 breeds; this must equal the number of class folders.
num_classes = 120
# Pretrained ViT-B/16 (224x224 input) with a num_classes-way classifier head.
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=num_classes)

criterion = nn.CrossEntropyLoss()
citerion = criterion  # keep the original (misspelled) name working
# Adam takes no `momentum` argument (its momentum terms are `betas`).
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.1 every 7 calls to scheduler.step().
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
scheduler = exp_lr_scheduler  # short alias
자, 진짜 학습 시작입니다! :)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Inputs are moved to `device` below, so the model must live there too.
model = model.to(device)

epochs = 10
best_acc = 0.0  # best validation accuracy seen so far
for epoch in range(epochs):
    print("{}/{} epoch running now".format(epoch, epochs - 1))
    for phase in ['train', 'val']:
        if phase == 'train':
            model.train()
        else:
            model.eval()
        running_loss = 0.0
        running_corrects = 0
        # `targets` rather than `labels` so the module-level class-name list
        # named `labels` is not overwritten by a tensor.
        for inputs, targets in dataloaders_dict[phase]:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            # Track gradients only while training.
            with torch.set_grad_enabled(phase == 'train'):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, targets)
                if phase == 'train':
                    # Backpropagate: compute the gradient of every parameter.
                    loss.backward()
                    # Update the weights according to the optimizer's rule.
                    optimizer.step()
            # Assumes the criterion averages over the batch (CrossEntropyLoss
            # default), so multiply back to get a per-sample sum.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == targets).item()
        if phase == 'train':
            # One scheduler step per training epoch.
            exp_lr_scheduler.step()
        epoch_loss = running_loss / dataset_sizes[phase]
        epoch_acc = running_corrects / dataset_sizes[phase]
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
        # Save the weights whenever validation accuracy improves.
        if phase == 'val' and epoch_acc > best_acc:
            best_acc = epoch_acc
            torch.save(model.state_dict(), './best_model.pth')
๐ ์ด๋ก์จ ๊ฐ์์ง ๋ฐ์ดํฐ๋ฅผ ํ์ฉํ์ฌ ViT ๋ชจ๋ธ์ Transfer Learning์ ๊ฐ๋จํ ์ค์ตํ๋ ์ฝ๋๋ฅผ ์์ฑํด ๋ณด์์ต๋๋ค.
๐ ํ์ง๋ง ๋
ผ๋ฌธ์ ๋ช
์๋ Optimizer๋, learning rate scheduler๋ ๋ค๋ฅธ ๋ถ๋ถ์ด ์๊ธฐ์, ์ด ๋ถ๋ถ์ ์ถํ ๊ฐ์ ์ฌํญ์ด ๋๊ฒ ๊ตฐ์ :)
https://tutorials.pytorch.kr/beginner/transfer_learning_tutorial.html
👉 시각화 코드는 위의 튜토리얼 페이지의 함수를 사용하였습니다.
def visualize_model(model, num_images=6, class_names=None, dataloader=None,
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot the model's predictions for the first ``num_images`` images.

    Adapted from the PyTorch transfer-learning tutorial.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier to evaluate; its train/eval mode is restored afterwards.
    num_images : int
        Number of images to plot, in a grid with 2 columns. Nothing is
        plotted when it is not positive.
    class_names : sequence of str, optional
        Maps a predicted class index to a display name. If omitted, the
        numeric class index is shown.
    dataloader : iterable of (inputs, labels), optional
        Source of images. Defaults to ``dataloaders_dict['val']``.
    mean, std : (R, G, B)
        Normalization statistics applied by the input transform; used here to
        undo the normalization before display.

    NOTE(review): relies on ``plt`` (matplotlib.pyplot) being available in the
    caller's namespace; this file does not import it.
    """
    if num_images <= 0:
        return
    if dataloader is None:
        dataloader = dataloaders_dict['val']
    # Run on whatever device the model's weights live on.
    device = next(model.parameters()).device
    mean_t = torch.tensor(mean).view(-1, 1, 1)
    std_t = torch.tensor(std).view(-1, 1, 1)
    n_rows = (num_images + 1) // 2  # 2 columns; round up for odd counts

    was_training = model.training
    model.eval()
    images_so_far = 0
    plt.figure()
    try:
        with torch.no_grad():
            for inputs, _ in dataloader:
                outputs = model(inputs.to(device))
                _, preds = torch.max(outputs, 1)
                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(n_rows, 2, images_so_far)
                    ax.axis('off')
                    pred = preds[j].item()
                    name = class_names[pred] if class_names is not None else pred
                    ax.set_title('predicted: {}'.format(name))
                    # Undo Normalize, clip to [0, 1], and reorder CHW -> HWC.
                    img = (inputs[j] * std_t + mean_t).clamp(0, 1)
                    ax.imshow(img.permute(1, 2, 0).numpy())
                    if images_so_far == num_images:
                        return
    finally:
        # Restore the caller's mode even if plotting raises.
        model.train(mode=was_training)
👉 아래와 같이 사용하면 됩니다.
# Show predictions for the first six validation images.
visualize_model(model, num_images=6)
๐ WOW! ๊ฐ๋์ ์ด๊ฒ๋ ๋ฐ์ด๋ ์ฑ๋ฅ์ ๋ณด์ฌ์ค ๋ชจ๋ธ์ ๊ฐ์ ธ์์ Transfer Learning์ ์ํํ๋ ๋ฐ ์ฑ๊ณตํ์์ต๋๋ค!
๐ ๋ฌผ๋ก , ๋
ผ๋ฌธ์ ๋ฒค์น๋งํฌ๋ฅผ ์ฌํํ๋ ค๋ฉด ๋
ผ๋ฌธ ๊ทธ๋๋ก์ Training sceinaro์ Optimizer, ๊ทธ๋ฆฌ๊ณ learning rate scheduling ๋ฑ์ด ํฌํจ๋์ด์ผ ํ์ง๋ง, ์ผ๋จ ํด๋ธ๊ฒ ์ด๋์์ :)
๐ ๋ค์ ๊ธฐํ์๋ ์ข๋ Advanced ํ ๊ตฌํ์ผ๋ก ๋์ ํด๋ณด๋๋ก ํ๊ฒ ์ต๋๋ค!