DeiT small 모델을 학습한 결과, 검증셋에 대한 F1-score가 꽤 높았습니다.
이 학습된 모델을 teacher model로 사용하여 student model을 knowledge distillation loss로 학습시키겠습니다.
teacher model의 지식을 전달할 student model은 아주 단순하게 정의해보았습니다.
class Student(torch.nn.Module):
    """Small CNN student used as the target of knowledge distillation.

    Architecture: one 3x3 convolution, ReLU, 2x2 max-pooling, adaptive
    average pooling to a fixed ``pool_size x pool_size`` grid, then a linear
    classifier.

    The previous head flattened the entire pooled feature map into
    ``Linear(9462528, 159)`` (768 * 111 * 111 inputs for a 224x224 image).
    That single layer holds about 1.5 billion weights, i.e. roughly 6 GB in
    float32 before gradients and optimizer state are counted. Pooling to a
    small fixed grid first keeps the classifier small and also removes the
    hard dependency on the input resolution.

    Args:
        num_classes: Number of output logits.
        channels: Number of output channels of the convolution.
        pool_size: Side length of the grid produced by the adaptive pooling.
    """

    def __init__(self, num_classes=159, channels=768, pool_size=4):
        super().__init__()
        self.seq = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=3, out_channels=channels, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            # Collapse the spatial map to a fixed size so the classifier's
            # input width does not scale with the image resolution.
            torch.nn.AdaptiveAvgPool2d(pool_size),
            torch.nn.Flatten(),
            torch.nn.Linear(channels * pool_size * pool_size, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        return self.seq(x)
from transformers import DeiTFeatureExtractor, DeiTForImageClassification, DeiTConfig
# Hub id shared by the image preprocessor and the pretrained DeiT backbone.
_DEIT_CHECKPOINT = 'facebook/deit-small-distilled-patch16-224'

extractor = DeiTFeatureExtractor.from_pretrained(_DEIT_CHECKPOINT)
deit_model = DeiTForImageClassification.from_pretrained(_DEIT_CHECKPOINT)

# NOTE(review): this only overwrites the `out_features` attribute of the
# existing classifier layer; it does not allocate new weights, so the layer
# presumably still emits the checkpoint's original number of logits
# (Teacher.out is built as Linear(1000, 159)) -- confirm this is intended.
deit_model.classifier.out_features = 159
class Teacher(torch.nn.Module):
    """Wrap a classification backbone and map its logits to 159 classes.

    ``model`` is called on the input and must return an object exposing a
    ``.logits`` attribute with 1000 values per sample. Those logits are
    passed through a ``Linear(1000, 159)`` head stored as ``self.out``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.out = torch.nn.Linear(1000, 159)

    def forward(self, x):
        """Return the 159-way logits for input batch ``x``."""
        backbone_logits = self.model(x).logits
        return self.out(backbone_logits)
# Assign the previously trained weights.
# map_location lets the checkpoint load even if it was saved on a different
# device than the one used in this session.
deit_model.load_state_dict(
    torch.load(
        "/content/drive/MyDrive/vision/OCR이미지분류/face_book_Levit_fold_0_epoch_9.pth",
        map_location=device,
    )
)

# NOTE(review): the checkpoint is loaded into `deit_model` only. The `out`
# head that Teacher creates in its constructor is therefore freshly
# initialised here -- confirm whether its trained weights should be loaded too.

# Load the teacher model.
teacher = Teacher(deit_model).to(device)
teacher.eval()

# The teacher only provides soft targets; freezing its parameters stops
# autograd from keeping a graph and accumulating gradients for it.
for teacher_param in teacher.parameters():
    teacher_param.requires_grad_(False)

# The default reduction ("mean") averages over every element and does not
# return the true KL divergence; "batchmean" matches the mathematical
# definition.
criterion = torch.nn.KLDivLoss(reduction="batchmean")
def distillation(criterion, student_pred, teacher_pred, temperature=2.0):
    """Compute the soft-target distillation loss between student and teacher.

    Both logit tensors are divided by ``temperature`` and normalised along
    ``dim=1``: the student with ``log_softmax`` and the teacher with
    ``softmax``, which is the (input, target) form ``torch.nn.KLDivLoss``
    expects.

    Args:
        criterion: Callable taking ``(log_probs, target_probs)``, e.g. a
            ``torch.nn.KLDivLoss`` instance.
        student_pred: Student logits; classes along ``dim=1``.
        teacher_pred: Teacher logits with the same shape as ``student_pred``.
        temperature: Softening temperature; must be positive. Defaults to
            2.0, the value previously hard-coded here.

    Returns:
        Whatever ``criterion`` returns for the softened distributions.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    log_student = torch.nn.functional.log_softmax(student_pred / temperature, dim=1)
    # The teacher is a fixed target: detach so backward() never walks into
    # (or keeps alive) the teacher's computation graph.
    soft_teacher = torch.nn.functional.softmax(teacher_pred.detach() / temperature, dim=1)
    return criterion(log_student, soft_teacher)
def train_loop(dataloader, student, teacher, distillation, criterion, optimizer, device):
    """Train ``student`` for one epoch against the teacher's soft targets.

    Args:
        dataloader: Iterable of batches; each batch is indexed with ``"img"``
            to get the input tensor. ``len(dataloader)`` must be supported.
        student: Model being optimised; switched to train mode.
        teacher: Model providing targets; switched to eval mode.
        distillation: Callable ``(criterion, student_pred, teacher_pred)``
            returning the loss tensor.
        criterion: Loss object forwarded to ``distillation``.
        optimizer: Optimizer stepping the student's parameters.
        device: Device the input batch is moved to.

    Returns:
        The sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    epoch_loss = 0
    student.train()
    teacher.eval()
    for batch in tqdm(dataloader):
        # Move the batch once and reuse it for both networks.
        images = batch["img"].to(device)
        student_pred = student(images)
        # The teacher is not being trained: running it without autograd
        # avoids storing its activations and accumulating gradients that the
        # optimizer would never clear.
        with torch.no_grad():
            teacher_pred = teacher(images)
        # Train with the distillation loss.
        loss = distillation(criterion, student_pred, teacher_pred)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= len(dataloader)
    return epoch_loss
@torch.no_grad()
def test_loop(dataloader, model, loss_fn, device):
    """Evaluate ``model`` over ``dataloader`` and collect its predictions.

    Args:
        dataloader: Iterable of dict batches with an ``"img"`` tensor and an
            optional ``"y"`` label tensor. ``len(dataloader)`` must be
            supported.
        model: Model to evaluate; switched to eval mode.
        loss_fn: Callable ``(pred, labels)`` returning a loss tensor; only
            called for batches that carry labels.
        device: Device the tensors are moved to.

    Returns:
        A tuple ``(epoch_loss, pred, true)``:
        ``epoch_loss`` is the summed loss of labelled batches divided by
        ``len(dataloader)``; ``pred`` is the concatenated softmax output
        (over ``dim=1``) as a NumPy array; ``true`` is the concatenated
        labels as a NumPy array, or ``None`` when no batch had labels.
    """
    epoch_loss = 0
    model.eval()
    pred_list = []
    true_list = []
    for batch in tqdm(dataloader):
        pred = model(batch["img"].to(device))
        labels = batch.get("y")
        if labels is not None:
            # Use the regular cross-entropy loss (whatever loss_fn is given).
            loss = loss_fn(pred, labels.to(device))
            epoch_loss += loss.item()
            # Labels are collected only when present, so unlabelled batches
            # no longer fail on a missing "y" entry.
            true_list.append(labels.to("cpu").numpy())
        probs = torch.softmax(pred, dim=1)
        pred_list.append(probs.detach().to("cpu").numpy())
    epoch_loss /= len(dataloader)
    pred = np.concatenate(pred_list)
    # np.concatenate rejects an empty list, so report "no labels" as None.
    true = np.concatenate(true_list) if true_list else None
    return epoch_loss, pred, true
# Hard-label loss, used only to report the validation loss.
loss_fn = torch.nn.CrossEntropyLoss()
batch_size = 16
# Soft-target loss for distillation. The default reduction ("mean") averages
# over every element and is not the true KL divergence; "batchmean" matches
# the mathematical definition.
criterion = torch.nn.KLDivLoss(reduction="batchmean")

teacher = teacher.to(device)
# NOTE(review): `student` is not instantiated anywhere in this section;
# presumably `student = Student()` is created elsewhere -- confirm.
student = student.to(device)

for i, (tri, vai) in enumerate(cv.split(train)):
    # Only the first fold is trained.
    if i != 0:
        continue

    optimizer = torch.optim.RAdam(student.parameters(), lr=0.00001)
    train_dt = Dataset(train[tri], target[tri], extractor)
    valid_dt = Dataset(train[vai], target[vai], extractor)
    train_dl = torch.utils.data.DataLoader(
        train_dt, batch_size=batch_size, shuffle=True, collate_fn=collate_fn
    )
    valid_dl = torch.utils.data.DataLoader(
        valid_dt, batch_size=batch_size, shuffle=False, collate_fn=collate_fn
    )

    best_score = 0
    patience = 0  # epochs since the last F1 improvement
    best_score_list = []
    num_epochs = 10

    for epoch in range(num_epochs):
        train_loss = train_loop(train_dl, student, teacher, distillation, criterion, optimizer, device)
        valid_loss, pred, true = test_loop(valid_dl, student, loss_fn, device)
        pred = np.argmax(pred, axis=1)
        score = f1_score(true, pred, average="weighted")
        print(f"distillation loss {train_loss}, valid loss : {valid_loss} , f1-score : {score}")

        patience += 1
        if best_score < score:
            # New best weighted F1: reset the counter and checkpoint the student.
            patience = 0
            best_score = score
            torch.save(student.state_dict(), f"/content/drive/MyDrive/vision/OCR이미지분류/Distillation_Deit_and_CNN_0fold_epoch_{epoch}.pth")
        # Early stopping after three consecutive epochs without improvement.
        if patience == 3:
            break
        print(f" Epoch ({epoch}), BEST F1: {best_score}")

    print(f"Fold ({i}), BEST F1: {best_score}")
    torch.cuda.empty_cache()
OOM(Out Of Memory) 에러로 학습이 중단되어 결과를 확인하지 못했습니다. 결과는 추후 업로드할 예정입니다.