📝 이번 포스트는 "Segment Anything" 논문에 대해 알아보도록 하겠습니다.
최근 대규모 웹 데이터셋을 활용한 "foundation model"이 자연어 처리(NLP)에서 Zero-Shot 및 Few-Shot 일반화 능력을 보이며 새로운 혁신을 가져왔습니다.
대표적으로 GPT 기반 언어 모델은 "prompt engineering"을 통해 다양한 작업에 적응하며, 사전 학습된 대규모 모델이 별도의 fine-tuning 없이도 새로운 데이터 분포와 작업을 처리할 수 있는 능력을 갖습니다.
컴퓨터 비전 분야에서도 CLIP과 ALIGN과 같은 모델이 텍스트와 이미지를 연결하는 방식을 통해 Zero-Shot 일반화를 실현했지만, 여전히 많은 컴퓨터 비전 과제에서는 충분한 학습 데이터가 부족한 상황입니다.
저자는 image segmentation을 위한 새로운 foundation model을 구축하고자 연구를 시작하였습니다.
이를 위해 프롬프트 기반 이미지 세그멘테이션(promptable segmentation)이라는 개념을 도입하여, 어떤 프롬프트가 주어지더라도 적절한 segmentation mask를 생성할 수 있는 모델을 설계하고, 이를 대규모 데이터셋으로 사전 학습하는 연구를 제시하였습니다.
SAM의 구체적인 내용을 설명하겠습니다.
이 연구는 자연어 처리(NLP)에서 사용하는 프롬프트 개념을 이미지 세그멘테이션으로 확장하였습니다.
프롬프트는 모델이 특정 객체를 분할할 수 있도록 제공하는 입력 정보로 사용됩니다.
프롬프트는 다음과 같은 형태로 정의됩니다.
주어진 프롬프트를 바탕으로 유효(valid)한 segmentation mask를 반환하는 것이 prompatble segmentation task의 핵심입니다.
유효(valid)한 segmentation mask는 프롬프트가 다의적으로 해석될 가능성이 있는 경우에도 적절한 객체를 세그멘테이션할 수 있도록 보장하는 것을 의미합니다.
위 그림 속 가장 두 번째 열과 같이 점을 찍었을 때, 그 점이 사람을 뜻하는지, 가방을 뜻하는지, 가방 속 파우치 부분을 뜻하는지 모호하게 해석될 수 있습니다.
이러한 경우에도 최소 한개의 마스크 최대 세 개의 마스크를 생성하도록 합니다.
Segment Anything Model (SAM)은 프롬프트 기반 세그멘테이션(promptable segmentation)을 수행하기 위해 설계된 모델입니다.
SAM은 image encoder, prompt encoder, mask encoder로 구성되어 있습니다.
SAM의 대략적인 개요는 아래와 같습니다.
SAM의 이미지 인코더는 입력 이미지를 저차원 임베딩으로 변환하는 역할을 합니다.
이 인코더는 임베딩을 한 번만 계산하며, 이후 프롬프트 입력 시에는 재사용되어 높은 연산량을 감당할 수 있도록 설계되었습니다.
MAE(Masked AutoEncoder)를 이용하여 사전 학습된 ViT/16(Vision Transformer, patch size=16) 모델을 이미지 인코더로 사용하였습니다.
SAM의 프롬프트 인코더는 다양한 유형의 프롬프트를 256차원 벡터 임베딩으로 변환하여, 이후 마스크 디코더가 이를 활용할 수 있도록 하는 역할을 합니다.
프롬프트 인코더는 sparse prompt와 dense prompt를 모두 처리할 수 있도록 설계되었습니다.
그리고 각 합성곱 사이에 GELU 활성화 함수와 layer normalization을 수행하였습니다.
마스크 디코더는 Transformer의 디코더 구조에서 영감을 받아 설계되었습니다.
저자가 수정한 디코더는 self-attention과 cross attention 기법을 활용하여 이미지와 프롬프트 정보를 결합합니다.
최종적으로 MLP 기반의 동적 선형 분류기를 사용하여 각 이미지 위치별 foreground 확률을 계산하여 마스크를 생성합니다.
SAM은 단일 프롬프트 입력이 여러 개의 유효한 마스크에 대응할 수 있는 모호성(Ambiguity) 문제를 해결하기 위해 다중 마스크 예측 방식을 도입하였습니다.
SAM은 기본적으로 단일 마스크 대신 whole, part, subpart 세 개의 마스크를 예측합니다.
학습 시 training loss는 각각 예측된 마스크들과 ground truth 마스크 간의 손실함수를 계산합니다.
하지만 역전파는 가장 낮은 loss를 갖는 마스크만 수행합니다.
예측된 마스크를 정렬하기 위해 추가적인 출력 토큰을 활용한 작은 헤드를 추가하였습니다.
이 헤드는 각 마스크가 객체를 얼마나 잘 커버하는지 나타내는 IoU 값을 예측합니다.
이를 통해 최적의 마스크를 선택하는 기준을 제공합니다.
모호성 문제는 단일 프롬프트에서 주로 발생하며, 여러 개의 프롬프트가 제공될 경우 거의 발생하지 않았습니다.
따라서 다중 프롬프트가 주어졌을 때는 기본적으로 단일 마스크만 예측하도록 설계하였습니다.
이것은 네 번째 출력 토큰을 추가하여 다중 프롬프트 입력 시 별도의 단일 마스크를 예측하도록 하였습니다.
이 네 번째 마스크는 단일 프롬프트에서는 출력되지 않고, 다중 프롬프트 시에만 출력됩니다.
실제 segmentation mask 데이터셋은 풍부한 양을 갖고 있지 않습니다.
연구진은 1.1B 개의 mask dataset을 가진 SA-1B를 확보하기 위해 데이터 엔진을 구축하였습니다.
이 데이터 엔진은 세 가지 단계로 구성되며, 각 단계에서 SAM 모델을 점진적으로 개선하여 보다 많은 마스크를 자동으로 생성하는 방식을 수행하였습니다.
첫 번째 단계로, interactive segmentation과 유사하며 사람(annotator)의 도움이 필요한 단계입니다.
사람이 SAM모델을 활용하여 마스크를 수동으로 주석을 다는 과정입니다.
사람들은 브라우저 기반의 interactive segmentation tool에서 background와 foreground를 점으로 클릭하여 객체를 라벨링 하였으며, 세밀한 조정을 위해 브러쉬 및 지우개 도구를 사용도 하였습니다.
두 번째 단계로, 첫 번째 단계에서 수집된 마스크를 활용하여, bounding box detector를 학습합니다.
Faster R-CNN을 bounding box detector로 사용하였습니다.
이 detector가 객체가 있을 가능성이 높은 위치를 자동으로 detect하고 SAM이 이를 바탕으로 마스크를 생성합니다.
이 단계에서는 사람의 도움이 첫 번째 단계에 비해 조금만 필요합니다.
사람이 자동으로 생성된 마스크를 검토하고 추가로 라벨링되지 않은 객체를 찾아 마스크를 만듭니다.
마지막 단계로, SAM이 모든 마스크를 자동으로 생성하는 방식으로 진행됩니다.
32x32 격자 형태로 점을 배치하여, 각 점에서 예측할 수 있는 모든 마스크를 생성합니다.
IoU 예측 모듈을 활용하여 가장 신뢰도가 높은 마스크를 선택하고, 중복되지 않고 신뢰도 높은 마스크만 유지합니다.
그리고 NMS(Non-Maximum Suppression) 알고리즘을 적용하여 중복된 마스크를 제거하여 신뢰도가 높은 마스크만 유지하도록 하였습니다.
1단계 Assisted-manual stage에서는 120k 개의 이미지로 4.3M개의 마스크를 생성하였고,
2단계 Semi-automatic stage에서는 180k 개의 이미지로 5.9M개의 마스크를 생성하였고,
3단계 Fully automatic stage에서는 11M 개의 이미지로 1.1B개의 마스크를 생성하였습니다.
이러한 데이터 엔진을 통해 기존 어떤 세그멘테이션 데이터셋보다도 400배 많은 고해상도 마스크가 포함된 SA-1B 데이터셋을 만들 수 있었습니다.
연구진은 SAM을 이미지 세그멘테이션을 foundation model의 패러다임으로 확장하려 하였습니다.
연구진은 프롬프트 기반 세그멘테이션이라는 새로운 개념을 도입하였고, 이를 지원하는 범용적인 세그멘테이션 모델과 대규모 데이터셋인 SA-1B를 구축하였습니다.