[논문 분석] TABR: TABULAR DEEP LEARNING MEETS NEAREST NEIGHBORS (ICLR 2024)

BING·2024년 8월 18일

[ 논문 분석 ]

목록 보기
4/16

요약

기존에는 tabular 데이터를 다룰 때 대부분 Gradient-Boosted Decision Trees (GBDT) 방법을 사용했지만, 최근 tabular 데이터를 위한 딥러닝 기술의 발전으로 DL(Deep Learning) 기반의 방법이 GBDT 방법과 유사하거나 더 좋은 성능을 보여주고 있음
이는 검색-증강(retrieval-augmented) 모델을 활용하여 보다 더 좋은 예측을 할 수 있게 되었기 때문.
여기서 검색 기반(retrieval-based) 방식이란 자연어 처리에서 주로 사용되는 방법인데, 특정 질문에 대한 답변(candidates)을 여러가지 뽑은 뒤에 그 중에서 가장 답변에 가까운 것을 뽑아 답변으로 채택하는 방식을 말함.
이를 Tabular 데이터셋에 적용한다면, target과 유사한 여러 샘플을 뽑아서 그 중에서 prediction을 하는 형식이라고 이해를 할 수 있음
검색 기반 방법을 사용하면 예측 성능을 향상시킬 수 있는 장점도 있지만, 부가적으로 purely parametric 방법(retrieval-free)에 비해 incremental learning이나 robustness에서의 성능 향상도 얻을 수 있다는 장점이 있음
추가로 GBDT 방법 대신에 DL 방법을 사용하면 domain adaptation, semi-supervised learning 등의 강력한 DL 방법을 적용하여 사용할 수 있는 장점이 있음

해당 논문에서는 새롭게 제안하는 TabR(Tabular Retrieval) 방법을 통해 여러 Tabular dataset 벤치마크에서 SOTA의 성능을 보여 줌.

[Background]

  • TabR는 Tabular Data에 대한 Deep Learning 모델 중 하나로, 이를 기반으로 한 연구들이 주목받고 있음.
  • 그러나 Gradient-Boosted Decision Trees(GBDT)가 여전히 탁월한 성능을 보여주고 있으며, 이를 능가하기 위한 연구들이 진행되고 있음. 특히, Retrieval-Augmented Models가 이러한 문제를 해결하기 위해 개발됨.
  • TabR는 이러한 모델 중 하나로, k-Nearest-Neighbors(KNN) 방식의 요소를 결합한 Feed-Forward 네트워크임.

[Motivation]

  • Tabular Deep Learning 모델들이 GBDT와 경쟁할 수 있도록 성능을 개선하고, 효율성을 높이기 위한 동기에서 연구가 시작됨.

[Contribution]

  • Tabular 데이터에 대한 새로운 Retrieval-Augmented 모델 제안.
  • 기존의 GBDT 모델과 비교하여 새로운 벤치마크에서 우수한 성능을 보여줌.
  • 기존의 Retrieval-Augmented 모델에서 사용되던 Attention mechanism을 수정하여 간단하고 효율적인 구조를 가짐.

[Research Flow]

  • 이 연구는 Tabular 데이터에 대한 Deep Learning 모델의 성능을 향상시키기 위해 Retrieval 요소를 결합한 모델을 제안하는 흐름에서 시작됨.
  • 기존 모델들과의 성능 비교를 통해 이 연구로 도달하게 되었음.

[ Related Work ]

Parametric deep learning models

Parametric tabular 딥러닝 모델은 딥러닝의 장점을 실제 tabular 데이터에 응용하고자 하는 연구임.최근 연구에서는 MLP와 유사한 backbone을 사용한 딥러닝 기반의 방법과, Transformer와 유사한 backbone을 가진 딥러닝 기반의 방법이 좋은 성능을 보여. 또한 continuous features에 대해서 임베딩 하는 새로운 방법 등을 적용함으로써 GBDT와 tabular DL 방법의 간극을 크게 줄이는데 성공함.

Retrieval-augmented models in general

일반적으로 검색 기반 방법은 다음과 같은 순서로 진행함.
1) 데이터셋에서 input object와 관련성이 있는 샘플들을 검색
2) input object와 뽑힌 샘플(input과 관련있는)들을 함께 가공하여 input object에 대한 더 좋은 최종 예측을 가능하게 함.
가장 흔하게 사용되는 검색 기반 방식은 local learning paradigm이 있고, 가장 쉽게 적용할 수 있는 모델 중 하나는 k-nearest neighbors (kNN) 알고리즘이 있음

Retrieval-augmented models for tabular data problems

기존에 존재하는 검색-증강 기반 DL모델들은 성능이 simple-parametric 방법에 비해 조금 좋은 정도에 그.그리고 무거운 Transformer 같은 구조를 사용하면서 계산량이 증가하게 됨. 해당 논문에서 제시하는 방법은 이전의 연구들에 비해 단 한개의 single-head attention 모듈을 사용하고, 해당 모듈을 cutomize 하여 기존의 방법들을 뛰어넘는 성능을 보여줌.

[Proposed Method]

Preliminaries

본 논문에서는 binary classification, multiclass classification, regression 이라는 3가지 태스크에 대해 고려함
데이터셋은 train, test, valid 의 세 영역으로 나누었으며, validation 파트는 early stopping 과 하이퍼 파라미터 튜닝에 사용함
검색 기반의 방법을 사용할 때 사용되는 "candidates"는 훈련 데이터셋으로 부터 얻음

실험 시에는 적절한 앙상블 성능을 얻기 위해서 15개의 random seed를 사용해 겹치지 않는 3개의 그룹으로 나누어 각 그룹의 평균 성능을 계산하고, 3개의 그룹의 평균 성능으로 결과를 표시

Overall Architecture

  • TabR의 기본 구조는 Feed-Forward 네트워크로, 중간에 Retrieval 모듈이 추가된 형태(residual branch로 더해주어 구현)
    • Retrieval 모듈은 k-Nearest Neighbors(KNN) 방식과 유사하게 동작하며, 입력된 대상 객체와 유사한 후보 객체들을 찾아 이들의 정보를 이용해 예측을 보강함.
  • 네트워크는 크게 다음 세 부분으로 나뉨:
    • 1) Encoder (E):
      • 입력 객체의 특징을 인코딩하여 중간 표현을 생성함. 이 중간 표현은 대상 객체와 후보 객체 모두에 대해 생성되며, 이후 Retrieval 모듈에서 사용됨.
    • 2) Retrieval Module (R):
      • 이 모듈은 대상 객체와 후보 객체 간의 유사도를 계산하여, 가장 유사한 m개의 후보 객체를 선택하고, 이들의 정보를 이용해 대상 객체의 중간 표현을 보강함.
    • 3) Predictor (P):
      • 최종 예측을 수행하는 모듈로, 보강된 중간 표현을 입력받아 결과를 생성함.

Encoder 와 Predictor는 각각 위의 그림과 같은 구조를 가짐. 본 논문에서는 encoder, predictor에 대해서는 크게 수정하지 않고 위의 구조를 그대로 사용함.
Encoder의 앞에 있는 Input Module 은 feature normalization이나 one-hot encoding 같은 input processing을 담당함

Retrieval Module의 동작 원리

본 논문에서는 k-nearest neighbor 방법을 응용한 retrieval module을 제안

  • Retrieval 모듈은 크게 1) Similarity Module2) Value Module로 나뉨.

1) Similarity Module:

  • TabR의 유사도 모듈은 기존의 Self-Attention 방식에서 사용되는 dot product 대신, L2 거리를 이용하여 대상 객체(쿼리)와 후보 객체(키) 간의 유사도를 계산함.
  • 기존의 Self-Attention 방식은 query와 key 간의 dot product을 이용해 유사도를 계산하였으나, TabR에서는 쿼리를 제거하고 key만을 사용하여 단순화시켰음. 이를 통해 성능을 크게 향상시켰음.
  • 이 L2 거리 기반 유사도 모듈은 탐색된 후보 객체들이 대상 객체와 더 유사한 이웃들이 되도록 보장하며, 이는 모델 성능 향상에 중요한 역할을 함.

2) Value Module:

  • Value Module은 후보 객체의 라벨 정보와 대상 객체와 후보 객체 간의 차이를 기반으로 대상 객체의 중간 표현을 보강함.
  • Value Module은 두 가지 요소로 구성됨:
    • 1) 후보 객체의 label 정보를 추가: 후보 객체의 라벨을 임베딩하여 대상 객체의 예측에 기여함.
    • 2) 차이 기반 보정: target 객체의 표현을 가져와서 value module의 표현력을 더 증가 시키는 것이 목적임.
      대상 객체와 후보 객체 간의 차이를 계산하고, 이를 보정하는 추가적인 값을 예측에 반영함. 이를 위해 최근에 제안된 DNNR 방법의 kNN 알고리즘을 사용.
      • 이 보정 값은 대상 객체와 후보 객체의 차이를 더 잘 반영할 수 있도록 도와줌.
      • 이러한 값 모듈의 보정 과정은 특히 회귀 문제에서 유용하게 작용하며, 대상 객체와 후보 객체 간의 관계를 더 정확하게 반영할 수 있음.

Incremental Design 으로 개선된 모델

  • TabR는 단번에 완성된 모델이 아니라, 점진적으로 개선된 모델임. 모델 설계 과정에서 여러 단계에 걸쳐 성능을 분석하고 개선하였음:

    • Step-0: 기본적인 Self-Attention 방식으로 시작하여, 이 방식이 단순 MLP(Multi-Layer Perceptron)와 비슷한 성능을 보이는 것을 확인함.
    • Step-1: 후보 객체의 라벨 정보를 값 모듈에 추가하였으나, 성능 개선이 미미했음.
    • Step-2: 쿼리 제거 및 L2 거리 사용으로 similarity Module을 개선하였고, 이 단계에서 성능이 눈에 띄게 향상됨.
    • Step-3: Value Module을 더 발전시켜 대상 객체와 후보 객체 간의 차이를 반영하는 보정 값을 추가함으로써 성능을 더 향상시킴.
    • Step-4: 추가적인 기술적 개선을 통해 최종적으로 TabR 모델을 완성함.

Overall Review

  • TabR의 Retrieval 모듈이 중요 포인트 : L2 거리 기반의 유사도 모듈과 Label 정보를 활용한 Value 모듈로 구성되어 있음. 이를 통해 성능을 향상시킴.

[Experiments]

Dataset

  • 다양한 공개 데이터셋을 사용하여 실험을 진행하였으며, 중소 규모의 데이터셋을 주로 활용함.

평가 지표

  • 실험에서는 정확도(Accuracy)와 Root-Mean-Square Error(RMSE) 등을 평가 지표로 사용함.

Implementation Details

  • Retrieval 모듈에서 k-Nearest-Neighbors 방식을 사용하며, 후보 객체들 중에서 가장 유사한 객체들을 선택하여 모델의 성능을 향상시킴.
    Retrieval module의 개선점들을 각 단계마다 적용한 성능을 위의 표2에서 확인할 수 있음. 각 데이터셋 별로 이전 단계보다 좋은 성능을 보이는 단계에 밑줄이 그어져 있.

[Results]

  • 각 모듈의 효과를 분석한 결과, L2 거리 기반 유사도 모듈과 Label 정보가 모델 성능에 큰 영향을 미친다는 것을 확인함.
  • 새롭게 제안된 TabR, TabR-S(Simple) 모델이 1.0 대의 평균 rank를 보이며 좋은 성능을 보여줌

    DL 기반의 방법들 뿐만 아니라 Tabular dataset에서 주로 사용되었던 GBDT 방법들과의 성능 비교표를 위 표4에서 확인할 수 있음. 마찬가지로 TabR 방법이 GBDT 방법들보다 더 좋은 성능을 보여줌.

[Conclusion]

Summary

본 논문에서는 기존에 사용되던 검색 기반의 DL 방법을 가져와 Retrieval module(similarity module과 value module)을 개선하여 최적의 성능을 달성함.
Retrieval Candidate에 따른 최종 예측 성능 분석과 같은 새로운 분석의 기회도 제공함.
추후에는 검색 기반 방식의 효율성을 높이고, 수 백만개의 데이터도 처리할 수 있는 모델을 만드는 것을 future works로 제시

논문 출처 : https://arxiv.org/abs/2307.14338
Yury Gorishniy, Ivan Rubachev, Nikolay Kartashev, Daniil Shlenskii, Akim Kotelnikov, Artem Babenko. TabR : Tabular Deep Learning Meets Nearest Neighbors In 2023. ICLR 2024.

profile
[ SPS Lab Paper Seminar YouTube ] : https://www.youtube.com/@spslab.1648

0개의 댓글