구(phrases)로 구성된 concept의 개념을 도입(아래 그림의 fantastic actor, fabulous acting 등)
: 시퀀스의 구성요소
: 를 비말단 구성요소로 분리한 것 (는 의 비말단 요소의 개수)
: 분류 모델

이 그림은 감성 분석 예제에서 SELFEXPLAIN 모델이 생성한 해석을 보여줍니다.
입력 문장은 "The fantastic actors elevated the movie"이며, 모델의 예측은 긍정(positive)입니다.
Word Attributions:
SELFEXPLAIN:
: 시퀀스의 구성요소
: 를 사전학습된 트랜스포머에 임베딩한 벡터
: 를 비말단 구성요소로 분리한 것 (는 의 비말단 요소의 개수)
: concept의 임베딩 벡터로, 아래 수식에 의해 계산. 즉, 해당 구 내 모든 단어 임베딩을 벡터별로 더한 뒤, 단어 수로 나누어서 평균을 냄.
: 에 포함된 단어들. 즉, 가 "the good soup"라는 구라면, 이 안의 단어들은 "the", "good", "soup"입니다.
: 입력 문장 내 단어 의 transformer 최종 레이어의 임베딩 벡터.
: 해당 구 에 포함된 단어의 개수. 예시에서 "the good soup"라면 3이 됨.
: 해당 구 내 모든 단어 임베딩 벡터의 합
입력 문장: "The chef cooks the good soup"
구 "the good soup"
단어 임베딩 벡터 (크기 3이라고 가정):
단어 수:
벡터 합:
평균:
"the good soup"이라는 구는 벡터 로 표현되며, 이것이 모델 내에서 해당 구를 대표하는 개념 벡터로 쓰임.
: 신택스 트리의 루트 노드
: 루트 노드에 대한 벡터. [CLS] 토큰의 출력 벡터 로 나타냄 (이밖에 모든 벡터에 대한 mean pooling이나 sum pooling 등의 방법을 사용했을 때, 결과가 비슷했기에 [CLS] 토큰의 벡터를 사용)
: 에 대한 분류 확률 분포.
: 예측된 클래스 인덱스
SELFEXPLAIN은 NLP용 self-explaining 모델로서, 모델의 예측을 설명할 때 단어 단위가 아니라 구 구간(phrase) 단위의 개념(concepts)을 사용
LIL은 이러한 개념 각각이 최종 예측에 얼마나 기여하는지 local relevance score를 부여하여 정량화
이 기여도 계산은 컴퓨터 비전 분야에서 널리 쓰이는 활성화 차이(activation difference) 기법을 참고
활성화 차이란, 특정 입력 개념이 모델의 최종 출력에 미치는 영향을 그 개념을 포함했을 때와 제외했을 때 활성화값의 차이로 측정하는 것
즉, LIL은 입력 문장을 구성하는 주요 문법 단위인 구(phrases) 하나하나가 모델의 최종 예측에 얼마나 영향을 미쳤는지를 활성화 차이를 통해 계산해, 해석 가능한 설명을 제공합니다.
def lil(self, hidden_state, nt_idx_matrix):
# hidden_state: 문장 내 각 토큰의 임베딩 벡터.
# nt_idx_matrix: 문장 내 각 구문이 어떤 토큰으로 이루어져있는지에 대한 정보를 인덱스화 한 것.
phrase_level_hidden = torch.bmm(nt_idx_matrix, hidden_state)
# 전체 구문에 원한 인코딩 된 구문별 인덱스를 곱하여 해당 구문이 아닌 것을 0으로 만듦.
# hidden_state: [batch, seq_len, hidden_dim]
# nt_idx_matrix: [batch, num_phrases, seq_len]
# 결과: [batch, num_phrases, hidden_dim] (김, 철수) 는 각 벡터를 더해서 (김철수) 하나의 벡터로 생성됨.
phrase_level_activations = self.activation(phrase_level_hidden)
# 각 구문별 임베딩에 ReLU 활성화 함수를 적용
pooled_seq_rep = self.sequence_summary(hidden_state).unsqueeze(1)
# hidden_state: [batch_size, seq_len, hidden_dim]
# self.sequence_summary(): 맨 앞의 [CLS] 토큰만 풀링 ([batch_size, hidden_dim])
# unsqueeze(1): [batch_size, 1, hidden_dim]로 변환 (브로드 캐스팅을 위해)
# 이렇게 하면 phrase_level_activations의 차원인 [batch, num_phrases, hidden_dim]과 알아서 차연산이 가능해짐
phrase_level_activations = phrase_level_activations - pooled_seq_rep
# 문장 전체 벡터를 통해 생성한 확률 - 위에서 계산한 확률
phrase_level_logits = self.phrase_logits(phrase_level_activations)
# 이 값을 nn.Linear() 분류기를 통과.
# 논문의 수식과 달리 Softmax는 여기서 생략하고, 나중에 손실 함수에서 사용.
return phrase_level_logits
phrase_level_hidden = torch.bmm(nt_idx_matrix, hidden_state)
hidden_state (1, 5, 4)[
[
[0.1, 0.2, 0.3, 0.4], # 1번째 토큰
[0.5, 0.6, 0.7, 0.8], # 2번째 토큰
[0.9, 1.0, 1.1, 1.2], # 3번째 토큰
[1.3, 1.4, 1.5, 1.6], # 4번째 토큰
[1.7, 1.8, 1.9, 2.0], # 5번째 토큰
]
]
nt_idx_matrix (1, 3, 5)[
[
[1, 0, 0, 0, 0], # 1번째 구문: 1번째 토큰만 포함
[0, 1, 1, 0, 0], # 2번째 구문: 2,3번째 토큰 포함
[0, 0, 0, 1, 1], # 3번째 구문: 4,5번째 토큰 포함
]
]
[1, 0, 0, 0, 0] × hidden_state[0.1, 0.2, 0.3, 0.4][0, 1, 1, 0, 0] × hidden_state[0.5+0.9, 0.6+1.0, 0.7+1.1, 0.8+1.2][1.4, 1.6, 1.8, 2.0][0, 0, 0, 1, 1] × hidden_state[1.3+1.7, 1.4+1.8, 1.5+1.9, 1.6+2.0][3.0, 3.2, 3.4, 3.6]phrase_level_hidden (1, 3, 4)[
[
[0.1, 0.2, 0.3, 0.4], # 1번째 구문
[1.4, 1.6, 1.8, 2.0], # 2번째 구문
[3.0, 3.2, 3.4, 3.6], # 3번째 구문
]
]
phrase_level_activations = self.activation(phrase_level_hidden)
pooled_seq_rep = self.sequence_summary(hidden_state).unsqueeze(1)
phrase_level_activations = phrase_level_activations - pooled_seq_rep
phrase_level_logits = self.phrase_logits(phrase_level_activations)
SELFEXPLAIN 모델의 Global Interpretability Layer(GIL)은 주어진 입력 샘플에 대해 훈련 데이터 내에서 가장 영향력 있는 개념 K개를 찾아 설명합니다. 이를 통해 새로운 입력에 대한 모델의 결정에 훈련 데이터의 어떤 개념들이 중요한 역할을 했는지 글로벌한 관점에서 이해할 수 있습니다.
: 훈련 데이터에서 추출한 모든 개념들의 집합
모델이 학습을 진행하면서 임베딩 레이어 가 계속 업데이트되어 단어 및 개념 벡터 표현도 변합니다.
따라서, 훈련 데이터에서 추출된 각 개념의 임베딩 도 시간이 지나면서 달라집니다.
이 때문에 인덱싱된 저장소 Q의 벡터들을 일정 주기(예: 매 몇 백 혹은 몇 천 학습 스텝마다)마다 다시 계산하고, 이를 인덱싱하여 최신 상태로 유지합니다.
이렇게 하면 GIL 레이어가 항상 최신 임베딩을 고려한 영향력 있는 개념들을 정확하게 검색할 수 있습니다.
## GIL 레이어 ##
# 문장 전체 벡터와 훈련 데이터의 개념 벡터들 사이의 내적을 통해 상위 K개의 개념을 추출
# 이 개념들에 대한 분류 결과 로짓 값과 그들의 인덱스를 계산
def gil(self, pooled_input): # pooled_input = sentence_cls: 문장 전체를 대표하는 벡터
batch_size = pooled_input.size(0) # [batch_size, hidden_dim]의 size(0) = batch_size
inner_products = torch.mm(pooled_input, self.concept_store.T)
# self.concept_store = [num_concepts, hidden_dim]
# 문장 전체 벡터와 개념 벡터들 사이의 내적을 계산
# inner_products = [batch_size, num_concepts]
_, topk_indices = torch.topk(inner_products, k=self.topk)
# 유사도가 가장 높은 K개의 개념 인덱스 선택
topk_concepts = torch.index_select(self.concept_store, 0, topk_indices.view(-1))
# self.concept_store: [num_concepts, hidden_dim] (훈련 데이터에서 추출한 모든 개념(phrase) 벡터의 집합)
# topk_indices: [batch_size, topk] 각 배치(문장)마다 유사도가 가장 높은 K개 개념의 인덱스
# topk_indices.view(-1): [batch_size * topk]로 펼쳐서, 전체 배치의 K개 인덱스를 1차원으로 만듦
# torch.index_select(self.concept_store, 0, ...) concept_store에서 해당 인덱스에 해당하는 개념 벡터만 추출
# 결과 shape: [batch_size * topk, hidden_dim] 배치 내 모든 문장에 대해, 각 문장별로 topk 개념 벡터를 한 번에 뽑아냄
topk_concepts = topk_concepts.view(batch_size, self.topk, -1).contiguous()
# 위에서 뽑은 [batch_size * topk, hidden_dim] 텐서를 [batch_size, topk, hidden_dim]로 다시 reshape
# .contiguous()는 메모리 상에서 연속적인 배열로 만들어줌(reshape 후 연산 최적화)
# 각 배치(문장)마다 topk개의 개념 벡터를 3차원 텐서로 정렬
# topk_concepts: [batch_size, topk, hidden_dim]
concat_pooled_concepts = torch.cat([pooled_input.unsqueeze(1), topk_concepts], dim=1)
# 문장 전체 벡터와 상위 K개 개념 벡터를 결합
# concat_pooled_concepts: [batch_size, 1 + topk, hidden_dim]
attended_concepts, _ = self.multihead_attention(query=concat_pooled_concepts,
key=concat_pooled_concepts,
value=concat_pooled_concepts)
# 어텐션 멀티헤드 레이어를 통해 각 개념 벡터에 대한 가중치를 계산
# 이 가중치는 문장 전체 벡터와 개념 벡터 사이의 유사도를 나타냄
# 이 가중치를 통해 개념 벡터들을 조합하여 새로운 벡터를 생성
# attended_concepts: [batch_size, 1 + topk, hidden_dim]
gil_topk_logits = self.topk_gil_mlp(attended_concepts[:,0,:])
# [batch_size, 1 + topk, hidden_dim]에서 [:, 0, :]는 각 배치의 0번째 벡터(즉, 문장 전체 벡터, 어텐션을 거친 후의 값)만 추출
# 결과 shape: [batch_size, hidden_dim]
# 이 벡터를 기준으로 [batch_size, hidden_dim] -> [batch_size, num_classes]로 변환되는 nn.Linear() 클래시피케이션 레이어 통과
# print(gil_topk_logits.size())
# gil_logits = torch.mean(gil_topk_logits, dim=1)
return gil_topk_logits, topk_indices
pooled_input: [batch_size, hidden_dim] self.concept_store: [num_concepts, hidden_dim] inner_products = torch.mm(pooled_input, self.concept_store.T)
[batch_size, num_concepts]_, topk_indices = torch.topk(inner_products, k=self.topk)
topk_concepts = torch.index_select(self.concept_store, 0, topk_indices.view(-1))
topk_concepts = topk_concepts.view(batch_size, self.topk, -1).contiguous()
[batch_size, topk, hidden_dim]topk_concepts = torch.index_select(self.concept_store, 0, topk_indices.view(-1))
self.concept_store: [num_concepts, hidden_dim] topk_indices: [batch_size, topk] topk_indices.view(-1): [batch_size * topk]로 펼쳐서,torch.index_select(self.concept_store, 0, ...) [batch_size * topk, hidden_dim]topk_concepts = topk_concepts.view(batch_size, self.topk, -1).contiguous()
[batch_size * topk, hidden_dim] 텐서를[batch_size, topk, hidden_dim]로 다시 reshape.contiguous()는 메모리 상에서 연속적인 배열로 만들어줌(reshape 후 연산 최적화)예시
topk_indices 예시:[[5, 2, 7], # 1번째 문장: 5,2,7번째 개념
[1, 0, 3]] # 2번째 문장: 1,0,3번째 개념topk_indices.view(-1) → [5,2,7,1,0,3]torch.index_select(...) → [6, 4].view(2, 3, 4) → [2, 3, 4] 정리
[batch_size, topk, hidden_dim] 형태로 배치별로 정렬concat_pooled_concepts = torch.cat([pooled_input.unsqueeze(1), topk_concepts], dim=1)
attended_concepts, _ = self.multihead_attention(
query=concat_pooled_concepts,
key=concat_pooled_concepts,
value=concat_pooled_concepts
[batch_size, topk+1, hidden_dim]pooled_input:[
[0.1, 0.2, 0.3, 0.4], # 문장1
[0.5, 0.6, 0.7, 0.8], # 문장2
] # shape: [2, 4]topk_concepts:[
[ # 문장1
[1.1, 1.2, 1.3, 1.4], # 1번째 개념
[2.1, 2.2, 2.3, 2.4], # 2번째 개념
[3.1, 3.2, 3.3, 3.4], # 3번째 개념
],
[ # 문장2
[4.1, 4.2, 4.3, 4.4],
[5.1, 5.2, 5.3, 5.4],
[6.1, 6.2, 6.3, 6.4],
]
] # shape: [2, 3, 4]unsqueeze(1)로 차원 맞추기
pooled_input.unsqueeze(1) [2, 4] → [2, 1, 4]cat으로 합치기
torch.cat([pooled_input.unsqueeze(1), topk_concepts], dim=1) [2, 1, 4]와 [2, 3, 4]를 dim=1(구문 차원)으로 합침 [2, 4, 4]concat_pooled_concepts:[
[ # 문장1
[0.1, 0.2, 0.3, 0.4], # 문장 전체 벡터
[1.1, 1.2, 1.3, 1.4], # 1번째 개념
[2.1, 2.2, 2.3, 2.4], # 2번째 개념
[3.1, 3.2, 3.3, 3.4], # 3번째 개념
],
[ # 문장2
[0.5, 0.6, 0.7, 0.8], # 문장 전체 벡터
[4.1, 4.2, 4.3, 4.4],
[5.1, 5.2, 5.3, 5.4],
[6.1, 6.2, 6.3, 6.4],
]
] # shape: [2, 4, 4]attended_concepts, _ = self.multihead_attention(
query=concat_pooled_concepts,
key=concat_pooled_concepts,
value=concat_pooled_concepts
)
멀티헤드 어텐션 적용
query, key, value 모두 concat_pooled_concepts attended_concepts: [batch_size, topk+1, hidden_dim] gil_topk_logits = self.topk_gil_mlp(attended_concepts[:,0,:])
return gil_topk_logits, topk_indices

인코더 레이어:
Linear Logits (중앙):
Local Interpretable Layer (LIL) (왼쪽):
Global Interpretable Layer (GIL) (오른쪽):
주요 목적: 클래스 예측의 조건부 로그 우도 (conditional log-likelihood)를 최대화합니다.
예측 레이어 포함: linear (레이블 예측용), LIL (Local Interpretability Layer), GIL (Global Interpretability Layer) 모두에서 각각의 출력에 대해 학습합니다.
정규화(regularization): LIL, GIL 레이어의 출력이 최종 태스크 성능 향상에 기여하도록 정규화합니다.
최종 손실 는 세 부분의 가중 합이며, 는 정규화 하이퍼파라미터.
linear 레이어의 레이블 예측 손실:
GIL 손실 :
LIL 손실 :
