scGPT: toward building a foundation model for single-cell multi-omics using generative AI

BITLAB·2025년 10월 31일

Cui, Haotian et al. "scGPT: toward building a foundation model for single-cell multi-omics using generative AI." Nature methods vol. 21,8 (2024): 1470-1480. doi:10.1038/s41592-024-02201-0

논문 링크

Abstract

  • word와 유전자의 유사성에 착안, NLP 분야에서 성공한 generative pretrained transformer 접근법을 Single cell 분야에 적용
  • 3300만 개 이상의 데이터를 기반으로 파운데이션 모델 scGPT 구축
  • 학습된 scGPT는 유전자 발현 및 세포에 관한 생물학적 데이터를 효과적으로 추출
  • 파인 튜닝 등을 통한 방법으로 다양한 다운스트림 작업에서 높은 성능을 보임

Introduction

  • 문제점
    • 현재 single-cell 분석 모델들은 특정 분석 목적에만 초점을 맞춰 분산되어 있음
    • 연구별 데이터셋의 범위와 규모가 제한적임
  • 파운데이션 모델의 필요성
    • 기존 파운데이션 모델(GPT-4, Enformer)들은 일반적 지식 능력을 가짐
    • 다양한 조직의 유전자 간 복잡한 상호작용을 학습할 대규모 모델이 필요
  • scGPT
    • 3,300만 개 이상 세포로 사전학습된 파운데이션 모델
    • non-sequential한 오믹스 데이터를 직접 모델링하는 통합 생성형 사전학습 워크플로우 구축
    • 사전 훈련-파인튜닝 접근 방식 제시
    • 장점
      1. 전이 학습 / 파인튜닝 성능 : 세포 유형 예측, perturbation 예측, 배치 보정, 다중 오믹스 통합 등 광범위한 다운스트림 작업에서 성능 달성
      2. 생물학적 지식 반영 : 유전자 임베딩과 어텐션을 통해 조건별 유전자-유전자 상호작용 규명

Method

Input Embeddings

모델의 입력을 구성하기 위해, scGPT는 유전자, 발현값, 조건이라는 세 가지 이질적인 정보를 통합하여 고차원 벡터 공간에 투영함

  • Gene Tokens :

    • 각 유전자 g는 자연어 처리의 단어와 같이 이산적인 토큰으로 간주 → 고유한 정수 ID에 매핑
    • 세포 i의 유전자 토큰 시퀀스는 아래와 같이 표현
      g(i)=[id(g1(i)),,id(gM(i))]\mathbf{g}^{(i)} = [id(g_1^{(i)}), \dots, id(g_M^{(i)})]
  • Expression Values :

    • 단일 세포 데이터는 실험 배치 간의 기술적 변이로 인해 절대적인 발현량의 스케일 차이가 큼

    • scGPT는 이를 정규화하기 위해 cell-specific value binning 기법을 도입

      1. 각 세포 i 내에서 0이 아닌 발현값을 가진 유전자들을 B개의 동일한 크기의 구간으로 나누기
      2. 유전자 j의 발현값 xi,jx_{i, j}는 자신이 속한 구간의 인덱스 k로 변환 → 절대적 수치를 상대적 순위 정보로 변환하는 효과를 가짐
        xj(i)={k,if Xi,j>0 and Xi,j[bk,bk+1]0,if Xi,j=0x_j^{(i)} = \begin{cases} k, & \text{if } X_{i,j} > 0 \text{ and } X_{i,j} \in [b_k, b_{k+1}] \\ 0, & \text{if } X_{i,j} = 0 \end{cases}
    • 이를 통해 발현 벡터 x가 생성됨

  • Condition Tokens:

    • perturbation 여부와 같은 위치별 메타 데이터를 나타내는 토큰 벡터

      tc(i)=[tc,1(i),,tc,M(i)]\mathbf{t}_c^{(i)} = [t_{c,1}^{(i)}, \dots, t_{c,M}^{(i)}]

  • 최종 입력 임베딩 (Final Input Embedding):

    h(i)=embg(g(i))+embx(x(i))+embc(tc(i))\mathbf{h}^{(i)} = \text{emb}_g(\mathbf{g}^{(i)}) + \text{emb}_x(\mathbf{x}^{(i)}) + \text{emb}_c(\mathbf{t}_c^{(i)})

    • 세 가지 구성 요소는 각각 별도의 임베딩 레이어를 통과한 후, 원소별 덧셈을 통해 M x D차원의 최종 입력 h 를 형성
    • emb(g)와 emb(c)는 표준적인 고정 길이 임베딩 벡터 레이어
    • emb(x)는 binning된 값의 순서 관계를 모델링하기 위해 완전 연결 레이어를 사용

Cell and gene expression modeling by transformers

scGPT는 표준 트랜스포머 인코더 스택을 활용하여 유전자 간의 복잡한 상호작용을 포착할 수 있음

ht(i)=transformer_block(ht1(i)),t[1,n].\mathbf{h}_t^{(i)} = \mathrm{transformer\_block}\left(\mathbf{h}_{t-1}^{(i)}\right), \quad \forall t \in [1, n].

  1. 입력 임베딩 h는 L개의 트랜스포머 블록 스택을 통과
  2. 최종 출력은 유전자 수준의 풍부한 문맥적 표현을 포함
  • 세포 표현 (Cell Representation)

    • BERT와 유사하게, 유전자 토큰 시퀀스 맨 앞에 특수 토큰 cls를 추가합니다
    • 이 토큰에 해당하는 최종 출력 벡터는 세포 전체의 상태를 요약하는 aggregated representation으로 사용
  • 배치 및 Modality 처리

  • 서로 다른 실험이나 데이터 유형(RNA, ATAC-seq 등)을 구별하기 위한 추가 토큰 : (배치 : tbt_b / Modality : tmt_m)

  • 토큰들은 트랜스포머 내부로 입력되지 않고, 트랜스포머를 통과한 output에 연결

    → 배치나 Modality 자체의 특성이 아닌, 생물학적 정보에 더 집중하도록 유도

    → 파인튜닝에서 배치 효과를 보정하는 데 도움

  • 다중 오믹스 연결
    hout(i)=concat(hL(i),embb(tb(i))+embm(tm(i)))\mathbf{h}_{\text{out}}^{(i)} = \text{concat}(\mathbf{h}_L^{(i)}, \text{emb}_b(\mathbf{t}_b^{(i)}) + \text{emb}_m(\mathbf{t}_m^{(i)}))

  • scRNA-seq 세포 레벨 통합
    hc,out(i)=concat(hc(i),embb(tb(i)))\mathbf{h}_{c, \text{out}}^{(i)} = \text{concat}(\mathbf{h}_c^{(i)}, \text{emb}_b(t_b^{(i)}))

Generative pretraining

GPT 모델은 입력 토큰을 바탕으로 다음 토큰을 예측 → 세포생물학적 접근으로 변환

⇒ cell type을 통해 유전자 발현량을 예측 (세포는 유전자로 정의)

비순차적(non-sequential) 데이터인 유전자 발현을 위한 생성형 학습 방식을 제안

  • Specialized Attention Mask

    • 기존 causal masking은 순서가 없는 유전자에 적용 불가

      → scGPT는 입력 유전자를 known genes와 unknown genes로 동적 분할함

    • 표준 attention 공식 :
      Attention(Q,K,V)=softmax(QKTd+Amask)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}} + A_{\text{mask}}\right)V

    • mask 구성

      aij={0,if j unknown genes,0,if i=j and j unknown genes,,if ij and j unknown genes.a_{ij} = \begin{cases} 0, & \text{if } j \notin \text{ unknown genes}, \\ 0, & \text{if } i = j \text{ and } j \in \text{ unknown genes}, \\ -\infty, & \text{if } i \ne j \text{ and } j \in \text{ unknown genes}. \end{cases}

      • 알려지지 않은 유전자가 오직 알려진 유전자와 cls 토큰, 그리고 자기 자신에게만 어텐션을 수행하도록 강제 ⇒ 모델은 예측해야 할 유전자(= unknown gene)가 다른 알려지지 않은 유전자에는 어텐션하지 않고, 세포 임베딩이나 이미 알려진 유전자, 그리고 자기 자신에게만 어텐션하도록 마스킹
      • prediction confidence를(Attention 산출물) 기반으로 예측 순서를 정의 → 가장 예측하기 쉬운 것부터 예측하고, 그것을 단서 삼아 다음 것을 예측하는 동적인 순서를 제작 가능
      • 추론 시, 가장 확신도 높은 예측부터 순차적으로 알려진 유전자 세트에 추가하는 반복적 생성 프로세스를 가능하게 하여 , 비순차적 데이터에 대한 autoregressive 생성을 구현
  • Pretraining Objective

    • 학습 목표는 알려지지 않은 유전자의 비닝된 발현값을 예측하는 것

    • 손실 함수는 예측값과 실제값 사이의 Mean Square Error
      L=1UunkjUunk(MLP ⁣(hn(i))xj(i))2\mathcal{L} = \frac{1}{|\mathcal{U}_{\text{unk}}|} \sum_{j \in \mathcal{U}_{\text{unk}}} \left( \mathrm{MLP}\!\left(\mathbf{h}_n^{(i)}\right) - x_j^{(i)} \right)^2

    • 훈련 중에는 gene-prompt 단계와 cell-prompt 단계가 순차적으로 진행

      • 각 단계에서 계산된 손실 값을 합산하여 모델 파라미터를 업데이트

      • Gene prompts → 유전자 발현 예측: 알려진 유전자 발현 값을 기반으로 알려지지 않은 유전자 발현 값을 생성하도록 모델을 훈련

      • Cell prompts → 전체 유전체 발현 예측 : 특정 세포 유형 조건을 입력하면 전체 유전체 발현 패턴을 생성하도록 훈련

        ⇒ 다양한 환경에서 유전자 및 세포 간의 복잡한 상호작용을 효과적으로 학습 가능

        ⇒ zero-shot(사전 훈련) 및 fine-tuning 환경 모두에서 문맥 인지 능력을 갖추도록 함

Fine-tuning Objectives

사전학습된 모델은 특정 다운스트림 작업에 맞게 다양한 목적 함수의 조합으로 미세 조정

  1. GEP (Gene Expression Prediction)

    • 입력 유전자의 일부를 무작위로 마스킹하고 예측
    • 유전자 간 공동 발현 관계 학습
      LGEP=1MmaskjMmask(x~j(i)xj(i))2\mathcal{L}_{\text{GEP}} = \frac{1}{|\mathcal{M}_{\text{mask}}|} \sum_{j \in \mathcal{M}_{\text{mask}}} \left( \tilde{x}_j^{(i)} - x_j^{(i)} \right)^2
  2. GEPC (GEP for Cell Modeling)

    • 유전자 발현을 세포 임베딩으로부터 예측
      LGEPC=1MmaskjMmask(x~j(i)xj(i))2\mathcal{L}_{\text{GEPC}} = \frac{1}{|\mathcal{M}_{\text{mask}}|} \sum_{j \in \mathcal{M}_{\text{mask}}} \left( \tilde{x}_j^{(i)} - x_j^{(i)} \right)^2
  3. ECS (Elastic Cell Similarity)

    • 세포 임베딩 간의 유사도가 특정 임계값을 넘도록 → 대조 학습 방식 손실
    • 유사한 세포끼리 가깝게 제작
      LECS=(sim ⁣(hc(i),hc(i))β)2\mathcal{L}_{\text{ECS}} = -\left( \mathrm{sim}\!\left(\mathbf{h}_c^{(i)}, \mathbf{h}_c^{(i')}\right) - \beta \right)^2
  4. DAR (Domain Adaptation via Reverse Backpropagation)

    • 배치 효과 제거를 위한 적대적 학습
      • hch_c로부터 배치를 예측하는 도메인 분류기 생성
        → 해당 분류기는 배치를 잘 맞추도록 학습
        → 메인 트랜스포머는 Gradient Reversal Layer를 통해 분류기가 배치를 못 맞추도록 학습 : cls 표현이 batch-invariant
  5. Cell Type Classification

    • 표준적인 지도 학습
    • hch_c에 MLP 분류기를 연결하여 세포 유형 레이블을 예측하고, 교차 엔트로피 손실을 최소화

Fine-tuning on downstream tasks

  • Cell type annotation
    • GT가 존재하는 참조 데이터셋을 사용해 모델 학습
    • 별도로 분리된 query set을 통해 성능 검증
    • 사전학습 파운데이션 모델과 참조 데이터셋 사이 공통적으로 존재하는 유전자만을 유지해 사용
    • 유전자 발현 값은 정규화 + 로그 변환 + binning을 거쳐 입력
    • CE를 최소화 하도록 학습
  • Perturbation response prediction
    • Highly Variable Genes 선별 + 발현 값 전처리
    • binning 값 대신 log 변환 발현 값을 input, target으로 사용 → 절대적인 발현량 예측 용이
    • 조건 토큰을 추가해 유전자가 교란 대상인지 모델에 전달
    • control cell을 input - perturbed cell을 target으로 사용하는 pair를 구성해 학습 → 대조군 전체의 유전자 발현량과 특정 유전자의 Perturbation token을 바탕으로 이후 세포 상태를 예측하도록 학습
  • Batch correction on integrating multiple scRNA-seq datasets
    • batch effects 교정을 목적으로 함→ 데이터 통합 시 배치 효과는 제거, 생물학적 다양성은 보존
    • 생물학적 정보 보존 : GEP / GEPC 사용
    • 배치 보정 : ECS / DAR / DSBN (Domain-Specific Batch Normalization) 동시에 최적화
  • Integrative representation learning for scMultiomic data
    • modality token을 사용
    • GEP와 GEPC 목적 함수로 최적화
    • DAR 목적 함수를 추가하여 배치 보정을 함께 수행
  • Gene regulatory network inference
    • 유전자 조절 네트워크(GRN)를 추론하는 두 가지 방법
      1. 유전자 임베딩 기반
        • 유전자 임베딩 벡터의 유사성 기반 네트워크
        • 사전학습된 모델의 유전자 임베딩으로부터 k-최근접 이웃을 기반으로 유전자 유사도 네트워크를 구축
        • 특정 데이터셋으로 파인튜닝된 모델의 유전자 임베딩을 사용하여 네트워크를 구축 ⇒ 유사도 그래프에 대해 Leiden 클러스터링을 수행 ⇒ 5개 이상의 유전자로 구성된 유전자 클러스터 추출 : 유전자 프로그램 (기능 유전자 집합)
      2. 어텐션 기반
        • Adamson dataset으로 파인튜닝
        1. control 세포와 perturbed 세포를 각각 모델에 입력 → 어텐션 맵 획득
        2. 마지막 어텐션 레이어의 헤드로부터 attention scores를 추출
        3. 정규화 수행 후 평균 내어 최종적인 aggregate attention map을 생성
        4. 관심 유전자 열 선택 후 score 비교 → 높은 순서대로 정렬 시 영향을 많이 받는 유전자가 만들어짐 : 유전자 조절 네트워크 추론 결과

Result

Single-cell transformer foundation model overview

  • scGPT
    • 대규모 cell atlas에 대한 pretraining과 특정 작업을 위한 fine-tuning이라는 두 단계로 구성된 파운데이션 모델
    • non-sequential 유전자 발현 데이터에 적용하기 위해 특별히 설계된 attention mask와 generative training pipeline 을 도입
    • 사전학습을 위해 CELLxGENE 컬렉션에서 수집한 3,300만 개 이상의 정상 인간 세포 데이터가 사용

scGPT improves the precision of cell type annotation

  • cell type annotation 작업에 미세조정한 결과
  • 3가지 데이터셋에서 높은 예측 정확도를 달성 → 건강한 세포로 학습시킨 모델이 MS 질병 상태의 세포를 정확히 예측 → 6가지 암 종류로 학습한 모델이 이전에 본 적 없는 3가지 암 종류의 세포 유형도 정확히 분류
  • scGPT는 TOSICA, scBERT에 비해 모든 평가지표에서 일관되게 우수한 성능을 보임
  • 췌장 데이터셋 클러스터링 결과 : GT와 Pred가 유사함

scGPT predicts unseen genetic perturbation response

scGPT가 genetic perturbation 실험의 결과를 예측하는 데 사용될 수 있음을 확인

셀프 어텐션 메커니즘을 통해 모델은 교란된 유전자와 다른 유전자들 간의 복잡한 상호작용을 학습 가능함

  • Prediction of unseen gene perturbations
    • perturbation 실험 데이터로 미세 조정 후 unseen gene의 perturbation 결과 예측
    • 3개의 Perturb-seq 데이터셋에서 perturbation 전후 발현 변화량 측정 → scGPT가 선형회귀 모델보다 5-20% 더 우수
    • 실험적으로 검증된 5%의 perturbations 조합 → in silico로 확장해 전체 조합 공간 반응 예측
  • In silico reverse perturbation prediction
    • 어떤 세포 상태를 얻기 위해 어떤 유전자를 조작해야 하는지를 찾는 역방향 예측 작업 → 치료 유전자 타겟을 찾는데 활용
    • Norman 데이터셋의 20개 유전자 하위 집합(210개 조합 가능) → 39개(18%)의 알려진 perturbation으로 학습 → scGPT는 세포 상태를 유발한 원인 유전자 조합을 Top-K 예측 내에서 성공적으로 찾음

scGPT enables multi-batch and multi-omic integration

  • Multi-batch scRNA-seq integration
    • 서로 다른 배치(다른 실험)에서 생성된 scRNA-seq 데이터를 통합하는 경우 → 생물학적 변이 보존 + 기술적 배치 효과 제거 필요
    • 마스킹된 유전자 발현을 복구하는 방식으로 미세조정 (SSL)
    • Seurat, Harmony 같은 기존 기술과 비교해 scGPT AvgBIO 점수가 5-10% 높음
  • Single-cell multi-omic integration
    • RNA, ATAC, 단백질 등 여러 모달리티의 정보를 통합하여 단일 세포 임베딩을 추출 가능
    • 모든 세포가 모든 모달리티를 가질 경우
      • BMMC (RNA+Protein)
        • 12개 배치, 9만 개 세포의 복잡한 데이터 에서 Seurat보다 9% 향상된 AvgBIO 점수를 기록
      • CD4 naive T 세포와 CD4+ activated T 세포 같은 미묘한 하위 유형을 분리 가능
    • 세포마다 모달리티가 다를 경우
      • ASAP human PBMC (RNA, ATAC, Protein이 4개 배치에 걸쳐 모자이크 형태로 존재) : scMoMatt과 비교 시, scGPT가 특히 B세포, 골수성 세포 등에서 더 우수한 배치 보정 성능

scGPT uncovers gene networks for specific cell states

  • 임베딩 기반 GRN
    • 사전학습 : 사전학습 모델의 임베딩만으로 HLA class I과 class II 유전자 그룹을 정확히 분리
    • 미세조정 : 면역 데이터셋으로 미세조정한 모델은 T세포 활성화, B세포 신호전달등 더 구체적인 면역 관련 유전자 네트워크를 포착
  • 어텐션 기반 GRN
    • 어텐션 맵은 단일 세포 수준의 문맥 특이적 상호작용을 포착
    • perturbation 전후의 어텐션 맵을 비교함으로써, 특정 유전자에 의해 가장 큰 영향을 받는 유전자를 추론 가능
      • DDIT3 유전자를 억제했을 때, scGPT가 어텐션 맵을 기반으로 선정한 상위 20개의 가장 영향받은 유전자는 실제 ChIP-Atlas 데이터베이스에 등록된 DDIT3의 표적 유전자와 100% 일치

Scaling and in-context effects in transfer learning

  • Scaling Effect : 다양한 크기의 데이터로 scGPT 모델들을 사전학습시킨 결과 , 사전학습 데이터의 양이 증가할수록 미세조정 후의 다운스트림 작업 성능이 일관되게 향상
  • In-context Effects : whole-human 데이터와 특정 장기 데이터로 사전학습한 모델의 차이 탐구
    • COVID-19 데이터에서 혈액, 폐 특화 모델이 뇌 특화 모델보다 우수함 → 사전학습 모델과 다운스트림 작업의 맥락을 맞추는 것이 중요
    • 범용적인 whole-human 모델도 사용 가능

Discussion

  • 언어 모델의 self-supervised pretraining 성공에 착안 → 복잡한 생물학적 상호작용을 규명하고자 single cell에 동일한 접근법을 적용
  • Transformer를 통해 유전자와 세포 임베딩을 동시에 학습 + 어텐션 메커니즘으로 single cell의 유전자 간 상호작용을 포착하여 해석 가능성을 확보
  • 사전학습 단계로만 새로운 데이터셋에서 세포 유형에 따른 의미 있는 클러스터링 가능 + 모델에 학습된 유전자 네트워크가 실험적으로 알려진 그룹과 일치
  • 파인튜닝 과정을 통해 특정 작업에 맞게 조정된 scGPT는 처음부터 학습된 모델보다 우수한 성능을 보임
  • 한계
    • 현 사전학습 모델은 배치 효과를 본질적으로 완화 불가
    • 명확한 GT 값이 부재하고 데이터 품질이 일관되지 않아 모델 평가가 복잡
  • 향후 계획
    • spatial omics 처럼 더 크고 다양한 질병 데이터를 포함하는 데이터셋으로 확장
    • 시간 데이터를 통합해 인과 관계 학습
    • 파인튜닝 없이 다운스트림 task에 적용하는 문맥 내 지시 학습 탐구 → 분석의 차이와 요구사항을 정확히 파악 : 유용성, 확장성 증가

Comment

유전 정보를 범용 모델 구조를 이용해서 학습했고, 일정 성능을 보였다는 게 유의미
attention 매커니즘이 성공적으로 정보 추출을 진행하여 실험적으로 해야 했던 부분들 상호보완 가능성 있음
내부적인 문맥을 고려하는 방법 이외 지식 내재화를 통하여 일반화 시킬 수 있는 방법이 있는지 궁금

profile
AI Insight with Bitlab

0개의 댓글