생존분석에 관해 공부하다가 읽은 논문에 관해 리뷰하고자 합니다.
Lee, C., Zame, W., Yoon, J., & van der Schaar, M. (2018). DeepHit: A Deep Learning Approach to Survival Analysis With Competing Risks. Proceedings of the AAAI Conference on Artificial Intelligence, 32(1). https://doi.org/10.1609/aaai.v32i1.11842
아직 논문을 읽고 정리하는 것에 익숙하지는 않지만, 꽤 흥미롭게 읽어서 글을 작성하기로 했습니다.
DeepHit는 네트워크를 사용하여 생존 시간 분포를 직접 학습하는 방식으로,
기본 확률 프로세스에 대한 가정을 하지 않으며, 시간이 지남에 따라 공변량과 위험 사이의 관계가 변할 가능성을 허용합니다.
DeepHit은 기본 확률론적 프로세스의 형태에 대해 어떠한 가정도 하지 않기 때문에, 고정된 원인(예: 질병)에 대해서도 매개변수와 확률론적 프로세스의 형태가 모두 공변수에 따라 달라질 수 있는 가능성을 허용합니다.

이 데이터에서 우리가 궁금한 것은 입니다. 즉, 어떠한 원인이 주어질 때, 사망 원인과 생존 시간이 궁금한 것이죠.
한정된 데이터셋으로는 를 알 수 없으니 true probabilities인 를 유추하는 것이 테스크가 됩니다.
network를 train하여 를 학습하는 것이 목표입니다.

DeepHit은 K cause-specific sub-networks와 shared sub-network로 구성된 multi-task network입니다.
DeepHit을 train하기 위해 censored data를 다루기 위해 설계된 을 최소화해야 합니다.
이며, 은 first hitting time과 이벤트에 대한 분포의 로그우도 함수이고, 는 cause-specific ranking loss funtion의 조합입니다.

은 DeepHit이 first hitting time과 이벤트 분포의 일반적인 표현을 학습하도록 합니다.

는 이벤트가 실제로 발생하는 시간을 통합하여 네트워크를 각 CIF에 맞게 미세 조정합니다.

Evaluation을 위해 5-fold cross validation을 적용합니다.
DeepHit은 4개의 레이어로 구성된 네트워크입니다:
shared sub-network를 위한 완전 연결 레이어와 각 cause-specific sub-network를 위한 2개의 완전 연결 레이어, 그리고 출력 레이어인 softmax 레이어.
레이어 1, 2, 3의 경우 ReLu 활성화 함수를 사용했습니다.
네트워크는 다음을 통해 역전파를 통해 학습되었습니다.
이는 tensorflow 환경에서 구현되었습니다.

DeepHit는 신경망을 훈련시켜 생존 시간과 이벤트의 추정 공동 분포를 학습하는 동시에 생존 데이터에 내재된 올바른 검열 특성을 포착합니다.
생존 시간과 상대적 위험 모두를 활용하는 손실 함수를 사용하여 네트워크를 훈련합니다.
단일 위험과 경쟁 위험 모두의 상황에서 이전 모델보다 더 나은 성능을 보였습니다.
생존분석을 딥러닝과 연결시켜 더 나은 성능을 보인 것이 흥미로웠습니다.
기존 생존분석은 통계적이고 수학적인 영역이라고 생각했는데, 딥러닝으로 확장되어 활용되는 것을 보니, 이런 부분들을 활용해보고 싶다고 생각했습니다.