[논문 리뷰] DeepHit: A Deep Learning Approach to Survival Analysis With Competing Risks

나나·2025년 3월 10일

생존분석에 관해 공부하다가 읽은 논문에 관해 리뷰하고자 합니다.

Lee, C., Zame, W., Yoon, J., & van der Schaar, M. (2018). DeepHit: A Deep Learning Approach to Survival Analysis With Competing Risks. Proceedings of the AAAI Conference on Artificial Intelligence, 32(1). https://doi.org/10.1609/aaai.v32i1.11842

아직 논문을 읽고 정리하는 것에 익숙하지는 않지만, 꽤 흥미롭게 읽어서 글을 작성하기로 했습니다.

Abstract

DeepHit는 네트워크를 사용하여 생존 시간 분포를 직접 학습하는 방식으로,
기본 확률 프로세스에 대한 가정을 하지 않으며, 시간이 지남에 따라 공변량과 위험 사이의 관계가 변할 가능성을 허용합니다.
DeepHit은 기본 확률론적 프로세스의 형태에 대해 어떠한 가정도 하지 않기 때문에, 고정된 원인(예: 질병)에 대해서도 매개변수와 확률론적 프로세스의 형태가 모두 공변수에 따라 달라질 수 있는 가능성을 허용합니다.

Survival Analysis

Survival Data


이 데이터에서 우리가 궁금한 것은 P(k=k,  s=sx=x)P(k=k*,\;s=s*|x=x*) 입니다. 즉, 어떠한 원인이 주어질 때, 사망 원인과 생존 시간이 궁금한 것이죠.
한정된 데이터셋으로는 PP를 알 수 없으니 true probabilities인 P^\hat P를 유추하는 것이 테스크가 됩니다.

Model Description

network를 train하여 P^\hat P를 학습하는 것이 목표입니다.

DeepHit은 K cause-specific sub-networks와 shared sub-network로 구성된 multi-task network입니다.

  • network가 각 이벤트의 공동 분포를 학습하도록 하기 위해 단일 softmax 레이어를 DeepHit의 출력 레이어로 활용합니다.
  • 입력 공변량에서 sub-network로의 연결을 유지합니다.
  1. 공유 하위 네트워크는 임상 공변량 x를 입력으로 받아 K 개의 경쟁 이벤트에 공통적인 (잠재적) 표현을 포착하는 벡터 fs(x)f_s(x)를 출력으로 생성합니다.
  2. 각 원인별 하위 네트워크는 z = (fs(x)f_s(x), x) 쌍을 입력으로 받아 특정 원인 k의 첫 번째 적중 시간 확률에 해당하는 벡터 fck(z)f_{ck}(z)를 출력으로 생성합니다.
  3. 이러한 출력의 총합은 first hitting time과 이벤트에 대한 공동 확률 분포이므로 sub-network는 각 원인에 대한 first hitting time에 대한 분포를 동시에 학습합니다.
  4. softmax 계층의 출력은 공변량 xx를 가진 환자가 주어지면 출력 yk,  sy_{k,\;s}는 환자가 시간 s에서 이벤트 k를 경험할 P^(s,  kx)\hat P(s,\;k|x)입니다.

Loss Function

DeepHit을 train하기 위해 censored data를 다루기 위해 설계된 LTotalL_{Total}을 최소화해야 합니다.

LTotal=L1  +  L2L_{Total} = L_1\;+\;L_2 이며, L1L_1은 first hitting time과 이벤트에 대한 분포의 로그우도 함수이고, L2L_2는 cause-specific ranking loss funtion의 조합입니다.


L1L_1은 DeepHit이 first hitting time과 이벤트 분포의 일반적인 표현을 학습하도록 합니다.


L2L_2는 이벤트가 실제로 발생하는 시간을 통합하여 네트워크를 각 CIF에 맞게 미세 조정합니다.

Experiments

Experimental Setting

Evaluation을 위해 5-fold cross validation을 적용합니다.
DeepHit은 4개의 레이어로 구성된 네트워크입니다:
shared sub-network를 위한 완전 연결 레이어와 각 cause-specific sub-network를 위한 2개의 완전 연결 레이어, 그리고 출력 레이어인 softmax 레이어.
레이어 1, 2, 3의 경우 ReLu 활성화 함수를 사용했습니다.
네트워크는 다음을 통해 역전파를 통해 학습되었습니다.
이는 tensorflow 환경에서 구현되었습니다.

Conclusion

DeepHit는 신경망을 훈련시켜 생존 시간과 이벤트의 추정 공동 분포를 학습하는 동시에 생존 데이터에 내재된 올바른 검열 특성을 포착합니다.
생존 시간과 상대적 위험 모두를 활용하는 손실 함수를 사용하여 네트워크를 훈련합니다.
단일 위험과 경쟁 위험 모두의 상황에서 이전 모델보다 더 나은 성능을 보였습니다.


생존분석을 딥러닝과 연결시켜 더 나은 성능을 보인 것이 흥미로웠습니다.
기존 생존분석은 통계적이고 수학적인 영역이라고 생각했는데, 딥러닝으로 확장되어 활용되는 것을 보니, 이런 부분들을 활용해보고 싶다고 생각했습니다.

profile
데이터를 설명하지 않고, 선택을 바꿉니다.

0개의 댓글