[Pytorch] Knowledge Distillation with DeiT small

Mollang · April 5, 2023

Training the DeiT small model produced a fairly high F1-score on the validation set.

  • Split the dataset into three folds and trained each for 10 epochs
  • Recorded an F1-score of 0.9139 at the last epoch of the first fold
  • Steady decrease in loss and increase in score

I will now use this trained model as the teacher and train a student model with a knowledge distillation loss.
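
For reference, Hinton-style knowledge distillation trains the student to match the teacher's temperature-softened outputs: with student logits z_s, teacher logits z_t, and temperature T, the soft-target loss is the KL divergence between softmax(z_t/T) and softmax(z_s/T), conventionally scaled by T². The loss defined later in this post uses T = 2.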

Define student model

The student model that will receive the teacher's knowledge is defined very simply.

import torch

class Student(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.seq = torch.nn.Sequential(
        # 3x224x224 input -> 768x222x222 (3x3 conv, no padding)
        torch.nn.Conv2d(in_channels=3, out_channels=768, kernel_size=3),
        torch.nn.ReLU(),
        # 768x222x222 -> 768x111x111
        torch.nn.MaxPool2d(2),
        torch.nn.Flatten(),
        # 768 * 111 * 111 = 9,462,528 flattened features -> 159 classes
        torch.nn.Linear(9462528, 159),
    )

  def forward(self, x):
    return self.seq(x)
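
If the 9,462,528 figure looks arbitrary, it follows from the 224×224 inputs the DeiT extractor produces: a 3×3 convolution without padding gives 222×222, max-pooling halves that to 111×111, and 768 · 111 · 111 = 9,462,528. A quick sanity check with a dummy batch (a throwaway snippet, not part of the training code):

# Dummy 224x224 RGB batch to confirm the flattened feature size.
x = torch.randn(1, 3, 224, 224)
feats = torch.nn.Sequential(
    torch.nn.Conv2d(3, 768, kernel_size=3),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Flatten(),
)(x)
print(feats.shape)  # torch.Size([1, 9462528])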

Load the teacher model

from transformers import DeiTFeatureExtractor, DeiTForImageClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

extractor = DeiTFeatureExtractor.from_pretrained('facebook/deit-small-distilled-patch16-224')
deit_model = DeiTForImageClassification.from_pretrained('facebook/deit-small-distilled-patch16-224')

# Note: assigning deit_model.classifier.out_features = 159 would not actually
# resize the pretrained head; it still emits 1000 ImageNet logits, so the
# Teacher wrapper below adds a 1000 -> 159 layer for our classes.
class Teacher(torch.nn.Module):
  def __init__(self, model):
    super().__init__()
    self.model = model
    self.out = torch.nn.Linear(1000, 159)

  def forward(self, x):
    x = self.model(x)
    x = self.out(x.logits)
    return x

# Assign the previously trained weights
deit_model.load_state_dict(torch.load("/content/drive/MyDrive/vision/OCR이미지분류/face_book_Levit_fold_0_epoch_9.pth"))

# Load the teacher model in evaluation mode
teacher = Teacher(deit_model).to(device)
teacher.eval()
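
Since the teacher is only ever used for inference, it also helps to disable gradients on its parameters entirely; eval() alone does not stop autograd from tracking the teacher's forward pass. A minimal sketch (my addition, not in the original post; note also that newer transformers releases deprecate DeiTFeatureExtractor in favor of DeiTImageProcessor):

# Freeze the teacher: its weights never need gradients.
for p in teacher.parameters():
    p.requires_grad_(False)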

Define knowledge distillation loss

# "batchmean" matches the definition of KL divergence; the default "mean"
# averages over every element, which PyTorch itself warns against.
criterion = torch.nn.KLDivLoss(reduction="batchmean")

def distillation(criterion, student_pred, teacher_pred):
  # Temperature T = 2.0 softens both distributions before comparing them.
  return criterion(torch.nn.functional.log_softmax(student_pred / 2.0, dim=1),
                   torch.nn.functional.softmax(teacher_pred / 2.0, dim=1))
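
Note that this trains the student purely on the teacher's soft targets. The classic formulation from Hinton et al. (2015) usually blends in ordinary cross-entropy on the ground-truth labels and scales the soft term by T² so its gradients stay comparable in magnitude. A sketch of that variant (alpha and the combined loss are my additions, not part of the training run below):

import torch.nn.functional as F

def distillation_with_labels(student_pred, teacher_pred, labels,
                             T=2.0, alpha=0.5):
    # Soft-target term, scaled by T^2 as in Hinton et al. (2015).
    soft = F.kl_div(F.log_softmax(student_pred / T, dim=1),
                    F.softmax(teacher_pred / T, dim=1),
                    reduction="batchmean") * (T * T)
    # Hard-target term: plain cross-entropy on the true labels.
    hard = F.cross_entropy(student_pred, labels)
    return alpha * soft + (1.0 - alpha) * hard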

Train loop

from tqdm import tqdm

def train_loop(dataloader, student, teacher, distillation, criterion, optimizer, device):
    epoch_loss = 0
    student.train()   # the student learns
    teacher.eval()    # the teacher only provides targets

    for batch in tqdm(dataloader):
        student_pred = student(batch["img"].to(device))
        # No gradients are needed through the teacher; no_grad saves memory.
        with torch.no_grad():
            teacher_pred = teacher(batch["img"].to(device))

        # Train on the distillation loss
        loss = distillation(criterion, student_pred, teacher_pred)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    epoch_loss /= len(dataloader)

    return epoch_loss
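
If memory is tight (as the OOM at the end of this post suggests), mixed precision can roughly halve activation memory. A sketch of the same loop under torch.cuda.amp (my variant, not the loop used for the reported run):

scaler = torch.cuda.amp.GradScaler()

def train_loop_amp(dataloader, student, teacher, distillation, criterion, optimizer, device):
    epoch_loss = 0
    student.train()
    teacher.eval()
    for batch in tqdm(dataloader):
        imgs = batch["img"].to(device)
        # Teacher runs without gradients, both forwards run in fp16/bf16.
        with torch.no_grad(), torch.cuda.amp.autocast():
            teacher_pred = teacher(imgs)
        with torch.cuda.amp.autocast():
            student_pred = student(imgs)
            loss = distillation(criterion, student_pred, teacher_pred)
        optimizer.zero_grad()
        scaler.scale(loss).backward()  # scale to avoid fp16 underflow
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    return epoch_loss / len(dataloader)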

Validation loop

import numpy as np

@torch.no_grad()
def test_loop(dataloader, model, loss_fn, device):
    epoch_loss = 0
    model.eval()

    pred_list = []
    true_list = []
    softmax = torch.nn.Softmax(dim=1)

    for batch in tqdm(dataloader):
        pred = model(batch["img"].to(device))

        if batch.get("y") is not None:
            # Plain cross-entropy loss against the ground-truth labels
            loss = loss_fn(pred, batch["y"].to(device))
            epoch_loss += loss.item()

        pred = softmax(pred)
        pred = pred.to("cpu").detach().numpy()
        true = batch["y"].to("cpu").numpy()

        pred_list.append(pred)
        true_list.append(true)

    epoch_loss /= len(dataloader)

    pred = np.concatenate(pred_list)
    true = np.concatenate(true_list)
    return epoch_loss, pred, true

Train

from sklearn.metrics import f1_score

loss_fn = torch.nn.CrossEntropyLoss()
batch_size = 16
criterion = torch.nn.KLDivLoss(reduction="batchmean")
teacher = teacher.to(device)
student = Student().to(device)  # instantiate the student defined above

# cv, train, target, Dataset, and collate_fn carry over from the earlier
# DeiT training post (three-fold split of the same dataset).
for i, (tri, vai) in enumerate(cv.split(train)):
  if i == 0:
      optimizer = torch.optim.RAdam(student.parameters(), lr=1e-5)
      train_dt = Dataset(train[tri], target[tri], extractor)
      valid_dt = Dataset(train[vai], target[vai], extractor)
      train_dl = torch.utils.data.DataLoader(train_dt, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
      valid_dl = torch.utils.data.DataLoader(valid_dt, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

      best_score = 0
      patience = 0
      num_epochs = 10
      for epoch in range(num_epochs):
          train_loss = train_loop(train_dl, student, teacher, distillation, criterion, optimizer, device)
          valid_loss, pred, true = test_loop(valid_dl, student, loss_fn, device)
          pred = np.argmax(pred, axis=1)
          score = f1_score(true, pred, average="weighted")
          print(f"distillation loss {train_loss},  valid loss : {valid_loss} ,  f1-score : {score}")

          # Early stopping: save on improvement, stop after 3 stale epochs
          patience += 1
          if best_score < score:
              patience = 0
              best_score = score
              torch.save(student.state_dict(), f"/content/drive/MyDrive/vision/OCR이미지분류/Distillation_Deit_and_CNN_0fold_epoch_{epoch}.pth")

          if patience == 3:
              break
          print(f" Epoch ({epoch}), BEST F1: {best_score}")

      print(f"Fold ({i}), BEST F1: {best_score}")
      torch.cuda.empty_cache()
      

Training was cut short by an OOM error, so the results could not be checked. I plan to upload them later.
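
One likely culprit for the OOM: the student's Flatten → Linear head alone holds 9,462,528 × 159 ≈ 1.5 billion weights, roughly 6 GB in fp32 before gradients and optimizer state are even counted. A sketch of a slimmer student that pools the feature map down before the classifier (my suggestion, not the configuration from this post):

class SmallStudent(torch.nn.Module):
    """Same conv stem, but global average pooling keeps the head tiny:
    the classifier shrinks from ~1.5B weights to 768 * 159 ≈ 122K."""
    def __init__(self, num_classes=159):
        super().__init__()
        self.seq = torch.nn.Sequential(
            torch.nn.Conv2d(3, 768, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool2d(1),   # 768x222x222 -> 768x1x1
            torch.nn.Flatten(),              # -> 768
            torch.nn.Linear(768, num_classes),
        )

    def forward(self, x):
        return self.seq(x)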
