ViT 이미지 분류 전체 실습 코드 (Colab)

이 글은 Google Colab 환경에서 HuggingFace의 Vision Transformer(ViT)를 사용하여 이미지 분류 작업을 수행하는 전체 실습 흐름을 담고 있어요. 데이터 전처리부터 학습, 평가, 추론까지 모든 코드를 단계별로 정리했습니다.

1. 구글 드라이브 마운트 및 환경 설정

from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/딥러닝 수업

!pip install -q transformers datasets
!pip install -q --upgrade datasets

2. 데이터 불러오기

from datasets import load_dataset
dataset = load_dataset("./data/data")

3. 시드 고정

import numpy as np
import random
import torch

seed = 999
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

4. 이미지 전처리 및 모델 준비

from transformers import AutoImageProcessor

model_name = 'google/vit-base-patch16-224-in21k'
feature_extractor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)

def preprocess_image(examples):
    images = [feature_extractor(image.convert("RGB")) for image in examples['image']]
    examples['pixel_values'] = [image['pixel_values'] for image in images]
    return examples

dataset = dataset.map(preprocess_image, batched=True)
dataset.set_format(type='torch', columns=['image', 'label', 'pixel_values'])

5. 학습/검증 분리 및 모델 정의

train_val_dataset = dataset['train'].train_test_split(test_size=0.2)
train_dataset = train_val_dataset['train']
val_dataset = train_val_dataset['test']
num_labels = len(dataset['train'].features['label'].names)

from transformers import ViTForImageClassification
model = ViTForImageClassification.from_pretrained(model_name, num_labels=num_labels)

6. 결과 디렉토리 생성

import os

output_dir = './result/vit_fruit_classfication'
log_dir = './log/vit_fruit_classfication'

os.makedirs(output_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

7. TrainingArguments 설정

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=output_dir,
    logging_dir=log_dir,
    learning_rate=2e-4,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
)

8. 평가 함수 정의

from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall}

9. Trainer 정의 및 훈련

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

trainer.train()

10. 모델 저장

cls_dir = './model/vit_fruit_classfication'
os.makedirs(cls_dir, exist_ok=True)
trainer.save_model(cls_dir)
model.save_pretrained(cls_dir)
feature_extractor.save_pretrained(cls_dir)

11. Pipeline으로 예측

from transformers import pipeline

image_classifier = pipeline('image-classification', model=cls_dir, feature_extractor=cls_dir)
image = dataset['test']['image'][10]
image = image.to(torch.uint8)
from torchvision.transforms import functional as TF
image = TF.to_pil_image(image)
pred = image_classifier(image)

print(pred)
print(dataset['test']['label'][10])

12. 수동 예측 (Processor + 모델 직접 사용)

from transformers import AutoImageProcessor, ViTForImageClassification

model = ViTForImageClassification.from_pretrained(cls_dir)
processor = AutoImageProcessor.from_pretrained(cls_dir)

idx = 1
image = dataset["test"]["image"][idx].to(torch.uint8)
image = TF.to_pil_image(image)
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits

probs = logits.softmax(dim=-1)[0]
pred_idx = probs.argmax().item()
pred_score = probs[pred_idx].item()

print(f"Predicted class: {pred_idx}, confidence={pred_score:.4f}")
print(dataset["test"]["label"][idx])

from transformers import AutoImageProcessor, ViTForImageClassification

model = ViTForImageClassification.from_pretrained(cls_dir)
feature_extractor = AutoImageProcessor.from_pretrained(cls_dir)

결과 요약

  • 정확도(Accuracy): 예시) 94.2%
  • F1 점수(F1-score): 예시) 93.8%
  • Precision / Recall: 클래스별로 균형잡힌 성능 확인
  • 학습 Epoch 수: 3회
  • 훈련/검증 데이터 비율: 8:2

예측 이미지 시각화 예시

import matplotlib.pyplot as plt

image = dataset["test"]["image"][idx]
plt.imshow(image.permute(1, 2, 0))  # shape: C, H, W → H, W, C
plt.title(f"Predicted: {pred_idx}, True: {dataset['test']['label'][idx]}")
plt.axis('off')
plt.show()

한계 및 개선 방향

  • 데이터 수가 적을 경우 과적합 가능성 존재
  • 클래스 간 불균형이 있다면 F1 점수 하락
  • 더 긴 학습이나 augmentation 기법 적용하면 성능 개선 가능
  • google/vit-base-patch16-224 외에 다른 사전학습 모델과 성능 비교 가능

마무리 정리

Colab 환경에서 ViT 모델을 활용한 이미지 분류 실습을 진행하며,

  • 데이터 전처리
  • 모델 구성 및 학습
  • 성능 평가 및 예측

까지의 전체 흐름을 경험했어요.
딥러닝 기반 이미지 분류를 처음 실습하는 분들에게 좋은 입문 가이드가 되었길 바라요


작성일 : 2025.07.02
작성자 : 발라

profile
능숙한 바이브코딩을 할 수 있게 됨을 꿈꾸며

0개의 댓글