1줄 요약: 앗! prompt tuning, finetuning보다 낫다!
Mozart was born in ___.
이라는 prompting은 Mozart의 DoB를 뽑아내기 위해 cloze question style로 치환한 것임!AutoPrompt
__(x) performed until his death in __(y)
: 앞의 blank가 남자라면 ok이지만, 앞의 blank가 여자라면....?discrete prompting (iPET)
It is [MASK]
etc.X. It is [MASK]
에서 [MASK]
를 예측verbalizer
가 많이 사용되었음model size가 10B보다 크면 finetuning outperform하긴 하는데, 10B 밑의 모델만 해도 BERT family 다 해당되고, 범용적으로 쓰이는 LM들 다 들어가 있는데 prompt tuning이 large model에서 superiority를 가져간다 한들 대체 무슨 의미?
sequence tagging 같은 건 iPet의 verbalizer같은 것으로는 해결이 매우 힘듦
prefix를 transformer에 통과시킨 이후 reparameterization encoder를 통과시켜서 model prediction하게 되는데, 이전에는 이게 보통 MLP였음: 그런데 task와 dataset에 따라서 들쭉날쭉한 성능을 보였음
prompt length에 따라서 task의 성능이 달라짐; simple classification task는 shorter prompt(20)에서, hard classification task는 longer prompt(100)에서 더 좋은 성능
본 논문에서 시도되지는 않았지만, 가능은 할 듯
Schick and Schutze. (2020) 에서 제시된 prediction y
를 prompt에 맞게 재구성한 verbalizer
의 사용은 지금까지 prompt tuning의 핵심이었음. 그런데 full-data setting에서는 필요없고, sequence labeling에서는 잘 작동 안 하더라.
그래서 P-tuning v2에서는 [CLS] 토큰 앞에 randomly-initialized classification head를 붙였음.
SuperGLUE datasets + sequence labeling task (NER, extractive Question Answering, semantic role labeling)
BERT-large, RoBERTa-large, DeBERTa-xlarge, GLM-xlarge/xxlarge
model size ranges in (300M, 10B)
Verbalizer with LM head: [MASK]의 representation에 LM 하나 더 통과시켜서 나오는 것
[CLS] label with linear head: [CLS]의 representation에 linear layer 통과시켜서 나오는 것
Verbalizer with LM head (iPET 논문: Schick and Schutze, 2021) 가 이전까지는 많이 쓰였던 형식이었음. P-tuning v2에서는 대신 linear head를 사용
결과만 놓고 보면 큰 차이는 없음: 그래서 쓰기도 귀찮고 어려운 verbalizer를 굳이 사용할 필요 없다
이전 연구와의 또다른 차이점은 중간 layers에도 prompting이 들어간다는 것. 그래서 embedding에 가까운 layer 몇 개(ascending order), 마지막에 가까운 layer 몇 개(descending order)에다가 prompts 넣어서 실험 돌림
결과는, 마지막 몇 개 layer에 prompting하는 게 전체 layer에 prompting하는 것과 흡사한 결과를 낸다는 것
왜 p-tuning v2냐?
BERT, roBERTa 죄다 base 말고 large로 해놓고 "universally on model scale"이라고 뻔뻔하게 얘기할 수 있나?
후속 연구:
Li, J., Cotterell, R., & Sachan, M. (2022). Probing via Prompting. arXiv preprint arXiv:2207.01736. (NASCL 2022)
prefix_encoder.py
import torch
class PrefixEncoder(torch.nn.Module):
r'''
The torch.nn model to encode the prefix
Input shape: (batch-size, prefix-length)
Output shape: (batch-size, prefix-length, 2*layers*hidden)
'''
def __init__(self, config):
super().__init__()
self.prefix_projection = config.prefix_projection
if self.prefix_projection:
# Use a two-layer MLP to encode the prefix
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
self.trans = torch.nn.Sequential(
torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
torch.nn.Tanh(),
torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
)
else:
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)
def forward(self, prefix: torch.Tensor):
if self.prefix_projection:
prefix_tokens = self.embedding(prefix)
past_key_values = self.trans(prefix_tokens)
else:
past_key_values = self.embedding(prefix)
return past_key_values
sequence_classification.py
)class BertPromptForSequenceClassification(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.bert = BertModel(config)
self.embeddings = self.bert.embeddings
self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
for param in self.bert.parameters():
param.requires_grad = False
self.pre_seq_len = config.pre_seq_len
self.n_layer = config.num_hidden_layers
self.n_head = config.num_attention_heads
self.n_embd = config.hidden_size // config.num_attention_heads
self.prefix_tokens = torch.arange(self.pre_seq_len).long()
self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
def get_prompt(self, batch_size):
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
prompts = self.prefix_encoder(prefix_tokens)
return prompts
def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size = input_ids.shape[0]
raw_embedding = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
)
prompts = self.get_prompt(batch_size=batch_size)
inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
outputs = self.bert(
# input_ids,
attention_mask=attention_mask,
# token_type_ids=token_type_ids,
# position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
# past_key_values=past_key_values,
)
# pooled_output = outputs[1]
sequence_output = outputs[0]
sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
first_token_tensor = sequence_output[:, 0]
pooled_output = self.bert.pooler.dense(first_token_tensor)
pooled_output = self.bert.pooler.activation(pooled_output)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
저도 얼마 전에 prefix tuning를 사용하는 논문을 읽었는데, 이 논문을 접하니 좀 더 원론적인 이해가 잘 되었습니다 ^00^