논문리뷰 - Masked Autoencoders Are Scalable Vision Learners

Hitdahit·2025년 12월 28일

논문리뷰

목록 보기
11/11

Abstract


  • masked autoencoders (MAE) → expandible self-supervised learning 방법임.
  • MAE 학습방법: input의 random patch에 대해 masking 후 recon하도록 학습한다.
    • 구체적으로 아래와 같은 구조를 제안함.
    1. asymmetric encoder–decoder architecture를 제안
      • encoder는 mask token 없이, 보이는 patch subset에 대해서만 연산을 수행한 후
      • lightweight decoder에서 latent representation과 mask tokens를 이용해 원본을 recon.
    2. 저자들은 입력 이미지의 높은 비율(예: 75%)을 masking하는 것이 의미있는 pre-task임을 확인함.
  • 장점:
    1. 학습 속도를 3배 이상 가속하면서도 정확도를 향상시킴.
    2. Generalize 된 high-capacity 모델을 학습시킬 수 있음,
    3. downstream task로의 transfer performance 우수 (supervised pretraining보다 우수)
      • 추가로 좋은 scaling behavior를 확인하였음

Introduction


  • Model Architecture의 capacity가 폭발적으로 증가하면서, 이제 백만장 쯤은 가볍게 overfit할 수 있음.
  • 이에 적절하게 generalize 된 모델을 학습시키기 위해 수억 장의 이미지가 필요해지기도 함
    • 그러나 이건 publically accessible 하지 않음.
  • 이러한 데이터 부족 문제는 NLP 쪽에서 해결된 사례가 있다.
    • self-supervised pre train을 통해 해결됨.
    • 정확힌 auto-regressive language modeling + masked autoencoding (GPT, BERT)
    • 제거된 데이터를 복원하도록 학습하는 방식.
  • 이를 통하여 백억 개 이상의 param을 가지는 모델도 generalizable 하도록 학습시킬 수 있게됨.
  • Masked Auto-encoder는 비전에도 적용가능하나, 아직까지 많이 연구되지 못했음
    • 이에 저자들은 masked autoencoding이 vision과 language 간에 다른 점을 아래의 관점에서 분석해봄.
      1. 아키텍처의 차이:
        • 비전에서 CNN이 너무 강력했었음.
        • CNN은 grid에서 연산을 수행 → mask token, positional embedding을 적용하기 어려움.
        • 이제는 ViT가 있으니, 이는 문제가 되지 않음.
      2. 정보 밀도의 차이:
        • language → 인간이 만든 semantic signal → 정보 밀도가 매우 높음.
          • 이에 masked auto-encoder가 학습하려면 언어를 매우 잘 이해하고 있어야 함,
        • 그러나 vision은 매우 redundant한 signal → 누락된 patch를 이웃한 패치들만 보고도 쉽게 복원 가능함
        • 그래서 vision에서는, masked auto-encoder가 그냥 써서는 잘 안 먹힘
          • 그래서 저자들은 매우 높은 비율의 패치들을 random masking 함
          • 이는 유용한 특징을 잘 학습하게 만들어 비전에서도 masking이 잘 작동하게 하였음.
      3. 디코더의 역할 차이:
        • language에선 decoder가 누락된 단어를 예측함에 있어 상대적으로 풍부한 semantic을 학습할 수 있음.
        • 그러나 vision에서는 decoder가 latent를 픽셀로 복원함에 있어 상대적으로 entropy가 낮은 수준의 정보를 다룰 수 밖에 없다.
          • ex.
            • I ate ___. 에서 뭘 예측하냐에 따라 아예 달라짐.
            • 픽셀 값이 뭐가 예측되던간에 대세에 영향을 주긴 어려움
        • 그래서 NLP에선 decoder 설계에 크게 신경쓰지 않아도 되지만,
        • vision은 디코더 설계가 latent representation의 semantic 수준을 결정하는 데 매우 중요함
    • 이에 저자들은 visual representation learning을 위한 MAE 구조를 제안함.
      • input에서 랜덤하게 patch를 마스킹 후 이 패치를 픽셀 공간에서 복원
      • 이를 비대칭한 encoder-decoder 설계로 구성함.
        • encoder: mask token 없이 visible subset만 처리.
        • lightweight decoder: latent representation과 mask token을 모두 받아 전체를 복원.
  • asymmetric encoder–decoder → 계산량을 크게 줄여줌.
    • 75%를 마스킹한다하면 encoder는 25%만 학습하면 됨. → 학습시간 약 3배이상 줄여줌.
    • 그러면서도 정확도는 최적화.
  • high-capacity 모델을 일반화 가능하게 학습할 수 있음.
    • 데이터 요구량이 큰 모델들(ViT-Large/Huge)을 ImageNet-1K에서 Generalibility를 개선함.
    • 또한 detection, instance / semantic segmentation 과 같은 downstream task에서도 supervised pre-training보다 성능이 높았음을 확인함.

Related Works


Masked language modeling & Auto-regressive methods

BERT [14], GPT [47, 48, 4]: NLP에서 pre-training을 사용한 성공적인 methodology.

  • 입력 시퀀스의 일부를 숨기고, 제거된 내용을 예측하도록 모델을 학습.
  • 위의 연구를 통하여 pre-train한 representation은, 다양한 downstream task로 일반화됨이 증명됨.
    • 즉, 확장성이 높다.

Autoencoding

  • representation learning을 위한 고전적 방법
    • encoder: 입력을 latent representation으로 mapping encoder
    • decoder: 입력을 재구성.
      • PCAk-means도 autoencoder의 일종으로 볼 수 있다.
    • Denoising autoencoders (DAE) [58]: 손상된 입력 신호를 원본으로 재구성하도록 학습.
      • 역시 autoencoder의 한 클래스이다.
      • 픽셀 마스킹 [59, 46] 또는 색 채널 제거 [70] 등 다양한 corruption 사용 가능.
      • MAE도 denoising autoencoding의 한 형태이지만, 표준 DAE와는 거리가 멀다.

Masked image encoding methods

  • masking으로 손상된 이미지로부터 representation learning을 유도하는 방법
    • 초기 [59]에는 DAE에서 제안된 noise를 일으키는 방법 중 하나로 제안됨.
    • 자매품 Context Encoder [46]: CNN으로 큰 mask 영역을 inpaint하면서 학습
  • 최근 방법들은 [6, 16, 2]은 Transformers [57] 기반이다. (NLP에서 masking 방식이 성공하면서부터)
    • iGPT [6]: pixel sequence input에서 masking 된 부분을 예측하면서 학습.
    • ViT [16]: self-supervised 학습을 위해 masked patch prediction을 pretext task로 활용.
    • BEiT [2]: discrete token 예측

Self-supervised learning

  • vision에서 가장 많이 관심받은 representation learning 방법.
  • pre-training을 위한 다양한 pretext task들이 제안됨 [15, 61, 42, 70, 45, 17].
  • 최근에는 contrastive learning [3, 22]이 가장 큰 관심을 받았음.
    • 이미지 간의 similarity/dissimilarity 모델링에 기반함
    • data augmentation에 크게 의존한다는 단점 존재함 [7, 21, 8].
  • 반면에 auto-encoder는 개념적으로 다른 방향임.

Approach


  • MAE는 masking된 입력으로부터 원본 신호를 재구성하는 단순한 autoencoding임.,
    • 관찰된 신호를 latent representation으로 매핑하는 encoder
    • 그 latent representation으로부터 원본 신호를 재구성하는 decoder로 구성
  • 대신, 고전적인 방식과 달리 asymmetric design으로 구성함.
    • encoder는 non-masked 패치에 대해서만 연산.
    • lightweight decoder는 latent representation과 mask token들로부터 전체 신호를 재구성. image.png
  • Masking
    • ViT [16]와 동일하게 먼저 이미지를 규칙적인 non-overlapping patch들로 쪼갬.
    • 그 다음, patch들의 subset을 샘플링하여 masking.
      • replacement 없이 random patches를 unifrom distribution으로 샘플링
    • 높은 masking 비율을 사용하여, redundancy를 크게 없앰.
      • visible neighboring patch들만으론 쉽게 prediction 하기 어렵게 만듦
    • Why Unifrom distribution?
      • center bias (중심 편향)를 방지하기 위해.
        • 즉, 이미지 중심부에 더 많은 mask가 쌓이는 문제를 피함.
  • MAE Encoder
    • encoder는 ViT [16]이지만 visible, unmasked patch들만 처리함.
    • 기존 ViT과 마찬가지로 아래의 순서대로 처리됨.
      • linear projection→patch embedding → positional embedding → Transformer block
    • mask token들은 사용하지 않음.
  • MAE Decoder
    • decoder는 전체 토큰 집합을 모두 사용한다.
      1. encoded visible patch
      2. mask token
      • missing patch임을 표시하는 하나의 shared learned vector.
    • 모든 token에는 positional embedding을 더한다.
      • mask tokens의 위치정보를 유지하기 위함.
    • 이후엔 transformer block을 쌓아 구성.
    • pre-training 동안에만 전체 이미지 reconstruction task를 수행하는 데 사용됨
      • downstream에선 버려짐.
      • encoder 설계와 별개로 설계할 수 있음.
        • 매우 작은 decoder를 만들어서 실험함. (encoder보다 너비나 깊이가 작도록)
          • encoder 대비 token 당 <10%의 계산량 수준으로.
  • Reconstruction Target
    • 각 masked patch에 대해 pixel 값을 예측하여 input을 재구성해야 함.
      • Decoder 출력 element는 patch를 나타내는 pixel vector.
    • 이를 위해 decoder의 마지막 레이어엔 linear layer projection 사용.
      • patch의 pixel 수와 동일한 출력 채널 수를 갖도록. 이후 reshape.
    • Loss로는 MSE 사용. (Pixel Space 에서 연산됨.)
      • masked patches에 대해서만 계산 (BERT와 같은 방법.)
    • normalized pixel values를 재구성하도록 학습도 해봄.
      • 각 patch의 모든 pixel에 대해 mean, std로 patch를 norm.
      • 이 세팅이 representation quality를 향상시킨다는 것을 실험적으로 확인.

0개의 댓글