-
안녕하세요 조쉬입니다.
오늘 리뷰할 내용은 정형 데이터를 딥러닝으로 풀이하는 SOTA 알고리즘인 TabNet입니다.
보통 딥러닝 네트워크가 활용되는 데이터들은 주로 텍스트, 이미지, 음성 등 비정형 데이터입니다.
NLP, OCR 등의 task를 생각하면 바로 떠오르는 딥러닝 아키텍처들이 있을 것이라 생각합니다. 하지만 정형 데이터를 다룰 때 네트워크를 사용한다면 어떻게 대응을 해야 하는가... 는 쉽게 생각이 나지 않으실 것 같습니다. 정형 데이터는 보통 lightgbm, xgboost 등 decision tree 기반의 아키텍처를 사용하여 진행하는 것이 일반적이기 때문일 것입니다. 하지만 최근에 나온 TabNet은 decision tree-based gradient boosting을 녹여낸 네트워크 모델로 기존의 트리 모델에서는 적용하기 힘든 다른 여러 테스크도 가능하게 만들어줍니다.
개요
우선 TabNet의 장단점부터 빠르게 알아보겠습니다.
- 장점
- 정형 데이터 단독 활용이 아닌 비정형 데이터와 결합 사용 가능
- shap & lime이 아닌 attention으로 XAI 가능
- Domain adaptation, Generative modeling, Semi-supervised learning 적용 가능
- 단점
- 비 직관적 변수 중요도 산출
- 정형 데이터 특성상 sparsity와 결측 값 그리고 non-sequential 구조로 인해 트리 모델이 manifold 학습에 적합한 구조임
Architecture Flow
해당 논문에서 사용된 예시 데이터는 미국 임금 설문조사 데이터입니다. 저자가 제시하는 sparse feature selection은 각 선택된 변수들 중 가장 영향력이 높은 변수들을 판단하기 때문에 모델의 해석력과 학습 능력을 향상한다고 주장합니다. 상기 이미지와 같이 변수를 선택 & 처리하는 decision block을 반복하며 모델의 학습이 진행되는 구조가 TabNet의 핵심입니다. 여러 층의 attention을 종합하여 최종적인 결정을 하는 bert와 사상적인 측면에서는 공유하는 부분이 있다고 생각합니다.
Encoder Architecture
인코더는 attentive transformer와 feature mask로 이루어져 있습니다. Split block은 feature transformer에서 처리된 값들을 attentive transformer에서 전체적인 산출물에 사용될 수 있도록 분리하는 역할을 합니다. 각 스탭에서 feature selection mask는 모델 기능 해석을 위한 정보를 축적하고 결합되어 전체적인 변수 중요도를 얻게 해 줍니다. 딥러닝의 구조를 활용하여 연속적인 모델 구조를 가져감과 동시에 변수를 제한적으로 활용한 결과들을 축적해 나가며 attention mask를 생성한다는 점에서 트리 구조의 모델과 비슷하면서 다르게 활용하는 양상이 인상적이었습니다.
Decoder Architecture
디코더 부분은 feature transformer를 fully connected layer로 이어 붙인 후에 변수들에 대한 정보를 다시 종합하는 방식을 취합니다.
Feature Transformer Architecture
해당 구조는 FC - BN - GLU를 하나의 단위로 4번 반복하는 구조를 가집니다. 4번의 반복 구조중 초반 2개의 구조는 모든 decision step에서 공유가 되며, 나머지 2개는 앞선 decision step에 의존적입니다.
GLU는 Gated Linear Unit의 약자로 인풋 벡터를 2개로 복사하여 하나는 sigmoid activation을 하고 나머지는 residual connection을 하여 최종적으로 두 벡터를 elementwise multiplication 하는 구조입니다.
Attentive Transformer Architecture
Learnable Mask를 활용하여 중요 변수들에 대한 정보를 학습하게 됩니다. Prior scale 정보는 현 decision step 이전에 feature 별 사용 빈도수를 집계한 정보입니다. 해당 단계의 Sparsemax가 coefficients의 normalization을 진행하며, sparse feature selection을 진행합니다.
사용법
해당 논문에서는 두 가지의 사용법을 제시합니다.
Unsupervised pre-training 방식과 Supervised fine-tuning 방식입니다. 제가 사용을 위해 코드 리뷰를 하던 당시에는 후자만 재현 코드가 남아있는 상태라 전자의 방식은 아직 사용해본 적이 없습니다. 현재는 모두 사용 가능한 것으로 보입니다.
Unsupervised pre-training는 인코더에서 생성된 embedding layer를 통하여 tabular feature를 재구성하는 decoder를 뒤에 붙여서 사용하는 방식입니다. 이미 학습된 embedding layer를 통하여 인풋 중 비어있는 부분을 masking 처리하여 embedding layer로 역으로 예측할 수 있도록 학습시키는 방식입니다.
Supervised fine-tuning은 기존의 tree 모형과 동일한 방식으로 사용이 진행됩니다.
Reference Code
GitHub - dreamquark-ai/tabnet: PyTorch implementation of TabNet paper : https://arxiv.org/pdf/1908.07442.pdf
위 링크로 들어가시면 pytorch 기반의 TabNet을 확인할 수 있습니다. TabNet semi-supervised learn, classification, regression 사용 시 sklearn 파이프라인과 동일하게 fit, predict, predict_proba 함수 사용이 가능합니다. 추가적으로 내장 explain함수를 통한 attention의 산출도 가능합니다.
인풋(tabular data)에서 numerical feature들은 scaling을 하는 편이 수렴을 빨라지게 하며 category feature들은 label encoding을 하여 진행하여야 합니다. 추가적으로 category feature의 column index와 해당 column의 label이 몇 개인지 모델 개체 생성 시에 명시해야 합니다.
해당 논문을 쓴 저자들이 TabNet 구조를 활용하여 추가적으로 개발한 SOTA 알고리즘들이 추가적으로 있는 것이 확인이 되어 다음 리뷰할 내용은 Temporal Fusion Transformer입니다. 시계열과 정형 데이터의 융합 알고리즘으로 수요예측 같은 시계열적 데이터 대응에 용이할 것으로 생각됩니다.
다음 글에서 뵙겠습니다.
읽어주셔서 감사합니다 :)
🤖 참고자료
TabNet 리뷰! 드디어 table (정형 데이터)를 위한 효과적인 딥러닝(인공신경망) 모델이 나타났습니다
TabNet 리뷰! 드디어 table (정형 데이터)를 위한 효과적인 딥러닝(인공신경망) 모델이 나타났습니다! xgboost, lgbm, catboost과 같은 decision tree-based gradient boosting 기법들의 장점을 이용한 모델입니다..
lv99.tistory.com
[논문 리뷰 및 코드구현] TABNET
[Review] TABNET: Attentive Interpretable Tabular Learning (2019) 이번 포스팅은 Tabular(정형) 데이터에 적합한 딥러닝 모델이라 주장하는 TABNET 논문 리뷰를 해보겠습니다. 궁금한점이나 해석이 잘못된 부분..
wsshin.tistory.com
Welcome to the SHAP documentation - SHAP latest documentation
Welcome to the SHAP documentation — SHAP latest documentation
© Copyright 2018, Scott Lundberg. Revision 46b3800b.
shap.readthedocs.io
Papers with Code - Residual Connection Explained
Papers with Code - Residual Connection Explained
Residual Connections are a type of skip-connection that learn residual functions with reference to the layer inputs, instead of learning unreferenced functions. Formally, denoting the desired underlying mapping as $\mathcal{H}({x})$, we let the stacked non
paperswithcode.com
'NN' 카테고리의 다른 글
7. Graph Neural Network(GNN) (0) 2022.03.03 6. Temporal Fusion Transformers (0) 2021.12.30 4. CNN과 Computer Vision (0) 2021.11.14 3. Neural Network 기초-2 (개념과 구조) (0) 2021.10.09 2. Neural Network 기초-1 ( 데이터 정의 및 활성함수) (0) 2021.09.12