
이 글은 Google Colab 환경에서 HuggingFace의 Vision Transformer(ViT)를 사용하여 이미지 분류 작업을 수행하는 전체 실습 흐름을 담고 있어요. 데이터 전처리부터 학습, 평가, 추론까지 모든 코드를 단계별로 정리했습니다.
# Mount Google Drive inside the Colab VM at /content/drive.
from google.colab import drive
drive.mount('/content/drive')
# Switch the working directory to the course folder on Drive; the relative
# paths used below (./data, ./result, ./log, ./model) resolve against it.
%cd /content/drive/MyDrive/딥러닝 수업
# Quietly install transformers and datasets, then upgrade datasets to the
# latest release.
!pip install -q transformers datasets
!pip install -q --upgrade datasets
from datasets import load_dataset
# Load the dataset stored under ./data/data.
# NOTE(review): later cells read dataset['train'], dataset['test'] and the
# 'image' / 'label' columns, so this directory presumably holds an image-folder
# style layout with train and test splits -- confirm against the data on Drive.
dataset = load_dataset("./data/data")
import numpy as np
import random
import torch

# Single seed value shared by every random number generator in this notebook.
seed = 999

# Apply the seed to NumPy, the stdlib RNG and PyTorch (default generator plus
# the CUDA generators), in the same order as the individual calls they replace.
for _seed_fn in (
    np.random.seed,
    random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)
from transformers import AutoImageProcessor

# Hub checkpoint name; the classifier created later is loaded from the same
# checkpoint.
model_name = 'google/vit-base-patch16-224-in21k'

# Image processor that belongs to the checkpoint, requesting the fast
# implementation.
feature_extractor = AutoImageProcessor.from_pretrained(
    model_name,
    use_fast=True,
)
def preprocess_image(examples):
    """Add a ``pixel_values`` column to a batch of examples.

    Every image in ``examples['image']`` is converted to RGB and the whole
    batch is run through ``feature_extractor`` in one call.

    Args:
        examples: Batch dict as passed by ``Dataset.map(..., batched=True)``.
            ``examples['image']`` must be a sequence of objects exposing
            ``convert("RGB")`` (PIL images).

    Returns:
        The same dict, with ``examples['pixel_values']`` holding one processed
        image per input image.
    """
    rgb_images = [image.convert("RGB") for image in examples['image']]
    # A single batched call yields one entry per image. Calling the processor
    # once per image and storing its 'pixel_values' unchanged would keep the
    # processor's own one-element batch wrapper around every example, so each
    # stored example would carry an extra leading axis.
    examples['pixel_values'] = feature_extractor(rgb_images)['pixel_values']
    return examples
# Run the preprocessing function over every split, one batch at a time.
dataset = dataset.map(preprocess_image, batched=True)

# Return only these columns, formatted as torch tensors.
torch_columns = ['image', 'label', 'pixel_values']
dataset.set_format(type='torch', columns=torch_columns)

# Hold out 20% of the original training split for validation.
full_train = dataset['train']
train_val_dataset = full_train.train_test_split(test_size=0.2)
train_dataset, val_dataset = (
    train_val_dataset[part] for part in ('train', 'test')
)

# Number of target classes, read from the label feature's class names.
num_labels = len(full_train.features['label'].names)
from transformers import ViTForImageClassification

# Map class ids to the dataset's class names (and back). Storing the mapping in
# the model config makes the saved checkpoint -- and anything that reads its
# config, such as an inference pipeline -- report readable class names instead
# of generic LABEL_<n> placeholders.
label_names = dataset['train'].features['label'].names
id2label = dict(enumerate(label_names))
label2id = {name: class_id for class_id, name in id2label.items()}

# Load the pretrained backbone with a classification head sized for this
# dataset's classes.
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)
import os

# Directories handed to TrainingArguments as output_dir and logging_dir.
output_dir = './result/vit_fruit_classfication'
log_dir = './log/vit_fruit_classfication'

# Create both up front; an already existing directory is not an error.
for _path in (output_dir, log_dir):
    os.makedirs(_path, exist_ok=True)
from transformers import TrainingArguments

# Training hyperparameters gathered in one mapping, then expanded into
# TrainingArguments.
_training_config = dict(
    output_dir=output_dir,
    logging_dir=log_dir,
    learning_rate=2e-4,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
)
training_args = TrainingArguments(**_training_config)
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
def compute_metrics(eval_pred):
    """Compute accuracy and weighted precision / recall / F1.

    Args:
        eval_pred: Pair ``(logits, labels)``. The predicted class for each
            sample is the arg-max of ``logits`` over the last axis.

    Returns:
        Dict with the keys ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    # zero_division=0 reports 0 for a class that has no predicted (or no true)
    # samples, which is the value scikit-learn falls back to anyway, but does
    # so without emitting an UndefinedMetricWarning on every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {'accuracy': acc, 'f1': f1, 'precision': precision, 'recall': recall}
from transformers import Trainer

# Bundle the model, training configuration, data splits and metric function.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)
trainer.train()

# training_args sets no evaluation strategy, so training alone never touches
# eval_dataset or compute_metrics. Evaluate explicitly so the validation
# metrics are actually computed and shown.
eval_metrics = trainer.evaluate()
print(eval_metrics)
# Directory that receives the trained model and the image processor config.
cls_dir = './model/vit_fruit_classfication'
os.makedirs(cls_dir, exist_ok=True)

# Same three writes as separate calls would perform, in the same order:
# the trainer's save, the model's own save, then the image processor's save.
for _save in (
    trainer.save_model,
    model.save_pretrained,
    feature_extractor.save_pretrained,
):
    _save(cls_dir)
from transformers import pipeline
from torchvision.transforms import functional as TF

# Build an image-classification pipeline from the saved directory. The saved
# preprocessor is an image processor, so it is passed via `image_processor`.
image_classifier = pipeline('image-classification', model=cls_dir, image_processor=cls_dir)

sample_idx = 10
# Index the row first, then the column: this reads and decodes only the one
# requested example, whereas dataset['test']['image'][i] selects the whole
# image column before picking a single element.
sample = dataset['test'][sample_idx]

# The dataset returns the image as a tensor; convert it to a PIL image for the
# pipeline.
image = TF.to_pil_image(sample['image'].to(torch.uint8))
pred = image_classifier(image)
print(pred)
# Ground-truth label of the same example, for comparison.
print(sample['label'])
from transformers import AutoImageProcessor, ViTForImageClassification

# Reload the fine-tuned classifier and its image processor from disk.
model = ViTForImageClassification.from_pretrained(cls_dir)
processor = AutoImageProcessor.from_pretrained(cls_dir)

idx = 1
# Index the row first, then the column: this reads only the requested example
# instead of selecting an entire column, and yields image and label together.
sample = dataset["test"][idx]
image = TF.to_pil_image(sample["image"].to(torch.uint8))
inputs = processor(images=image, return_tensors="pt")

# Inference only: no gradients are needed for the forward pass.
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits

# Softmax over the class axis; [0] selects the single image in the batch.
probs = logits.softmax(dim=-1)[0]
pred_idx = probs.argmax().item()
pred_score = probs[pred_idx].item()
print(f"Predicted class: {pred_idx}, confidence={pred_score:.4f}")
# Ground-truth label of the same example, for comparison.
print(sample["label"])
from transformers import AutoImageProcessor, ViTForImageClassification

# Reload the saved classifier and image processor from disk.
model = ViTForImageClassification.from_pretrained(cls_dir)
feature_extractor = AutoImageProcessor.from_pretrained(cls_dir)

import matplotlib.pyplot as plt

# Read the example once, row first, instead of selecting the image column and
# the label column separately and indexing into each.
row = dataset["test"][idx]

# imshow expects channels last: (C, H, W) -> (H, W, C).
plt.imshow(row["image"].permute(1, 2, 0))
plt.title(f"Predicted: {pred_idx}, True: {row['label']}")
plt.axis('off')
plt.show()
google/vit-base-patch16-224-in21k 외에 다른 사전학습 모델과 성능 비교 가능
Colab 환경에서 ViT 모델을 활용한 이미지 분류 실습을 진행하며, 데이터 전처리부터 학습, 평가, 추론까지의 전체 흐름을 경험했어요.
딥러닝 기반 이미지 분류를 처음 실습하는 분들에게 좋은 입문 가이드가 되었길 바라요.
작성일 : 2025.07.02
작성자 : 발라