[Pytorch] 이미지 분류 with ViT

Mollang·2023년 3월 22일
0

AI hub에서 제공하는 이미지 데이터셋을 활용하여 이미지 분류 태스크를 수행한다. VGG16으로 테스트하였을 때 망을 몇 개 더 쌓았더니 성능이 좋지 않아 중간에 중단하였다. 대안으로 ViT 모델을 테스트하려 한다.

Vision Transformer(ViT) 모델은 자연어 처리에 사용되는 Transformer를 이미지에 적용한 모델인데, NLP에서 텍스트 데이터를 토큰으로 다루 듯 이미지를 패치 조각으로 분리한다. ViT모델이 중간 크기의 데이터셋보다는 더 큰 규모의 데이터셋으로 학습하였을 때 더 좋은 성능이 나온다는 후기가 있다. 현재 보유하고 데이터셋은 크기가 작을 뿐더러 1주일 안에 모델 성능을 올려야 하기 때문에, 이 모델로 테스트를 하는 것이 괜찮은 선택인지 우려된다.

일단 결과를 봐야 알 것 같다.

사전학습 모델은 환경 및 장비에 따라 학습에 소요되는 시간이 "정말 짧으면" 6시간 내외이며 "최악으로 오래 걸릴 경우" 며칠 이상이 소요되기도 한다. 보통 자연어 처리 태스크를 위해 사전학습 모델을 파인튜닝하여 모델 학습을 진행하였는데, RAM 부족 문제로 인해 학습 중단 에러가 빈번히 발생했다. 두 개의 코랩 프로 계정으로 동시에 학습하였는데도 모델 하나를 온전히 돌리지도 못한 적이 많다.

  • 코랩 학습 환경 : GPU > 프리미엄 > 고용량 RAM

데이터셋 클래스 정의

class Dataset(torch.utils.data.Dataset) :
  def __init__(self, transform  , img_path_np, target=None ):
    self.imgs = img_path_np 
    self.target = target
    self.transform = transform
  
  def __len__(self):
    return len(self.imgs)

  def __getitem__(self, idx):
    item = {}
    file_path = self.imgs[idx] 
    img = cv2.imread(file_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  
    item['img'] = self.transform(image = img)['image']
    item['y'] = self.target[idx]  # len 159
    return item

# transformer 
transformer = albumentations.Compose([
    albumentations.Normalize(),
    albumentations.Resize(height = 224, width = 224),
    albumentations.pytorch.transforms.ToTensorV2(),
])


# transformer oneof > albumentations 
transform_oneof = albumentations.Compose([
    albumentations.Normalize(),
    albumentations.Resize(224, 224), 
    albumentations.OneOf([
                          albumentations.MotionBlur(p=1),
                          albumentations.OpticalDistortion(p=1),
                          albumentations.GaussNoise(p=1)                 
    ], p=1),
    albumentations.OneOf([
                          albumentations.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
                          albumentations.OpticalDistortion(p=1),
                          albumentations.GaussNoise(p=1)                 
    ], p=1),
    albumentations.pytorch.transforms.ToTensorV2(),
])

train & valid loop

def train_loop(dataloader,model,loss_fn,optimizer,device):
    epoch_loss = 0 
    model.train() 
    for batch in tqdm(dataloader): 
        pred = model(batch["img"].to(device)).to(device)
        loss = loss_fn(pred, batch["y"].to(device))   
        optimizer.zero_grad() 
        loss.backward()  
        optimizer.step() 
        
        epoch_loss += loss.item() 

    epoch_loss /= len(dataloader) 

    return epoch_loss


@torch.no_grad() 
def test_loop(dataloader,model,loss_fn,device): 
    epoch_loss = 0
    model.eval() 
    
    pred_list = []
    true_list = []
    softmax = torch.nn.Softmax(dim=1) 

    for batch in tqdm(dataloader):   
        pred = model(batch["img"].to(device)).to(device)
        
        if batch.get("y") is not None: 
            loss = loss_fn(pred, batch["y"].to(device))
            epoch_loss += loss.item()
        
        pred = softmax(pred)
        pred = pred.to("cpu").numpy() 
        true = batch['y'].to('cpu').numpy()

        pred_list.append(pred)
        true_list.append(true)

    epoch_loss /= len(dataloader)

    pred = np.concatenate(pred_list) 
    true = np.concatenate(true_list)
    return epoch_loss , pred , true

model load

이미지를 159개의 클래스로 분류하는 태스크이다.

!pip install timm

import timm 

num_classes = 159
VIT = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=num_classes).to(device)

train code


seed_everything(seed)

optimizer = torch.optim.Adam(model.parameters(), lr = 0.00001)
loss_fn = torch.nn.CrossEntropyLoss() 
train_x, valid_x, train_y, valid_y = train_test_split(train , target, test_size=0.2, random_state=77)
data_train = Dataset(transform_oneof, train_x, train_y)
data_test = Dataset(transform_oneof, valid_x, valid_y)
train_dl = torch.utils.data.DataLoader(data_train, batch_size = 8, shuffle = True)
test_dl = torch.utils.data.DataLoader(data_test, batch_size = 8, shuffle = False)

best_score = 0
patience = 0
best_score_list = []
num_epochs = 100 
model = VIT.to(device)

for epoch in range(num_epochs):
    train_loss = train_loop(train_dl, model , loss_fn,optimizer,device)
    valid_loss , pred , true = test_loop(test_dl, model , loss_fn,device  )      
    pred = np.argmax(pred, axis=1) 
    score = f1_score(true, pred , average="weighted")
    print(f"train loss {train_loss},  valid loss : {valid_loss} ,  f1-score : {score}")
    patience += 1
    if best_score < score:
        patience = 0
        best_score = score
        torch.save(model.state_dict(), f"/content/drive/MyDrive/vision/OCR이미지분류/google_base_vit_net_{epoch}.pth")

    if patience == 3:
        break

    print(f" Epoch ({epoch}), BEST F1: {best_score}")
    best_score_list.append(best_score)
    torch.cuda.empty_cache()

이미지 3399개를 1번 '학습'만 하는 데 최소 4시간은 걸릴 것 같다. 내일 오전 중으로 학습이 완료되지 않으면 중단하고 다른 모델 학습을 시도할 예정이다.

0개의 댓글