[논문리뷰]Bert4Rec

김동환·2023년 6월 7일
0

AI_tech_5기

목록 보기
18/18

Abstract

기존의 Sequential 방법론들은 단방향으로만 학습이 이루어져 몇가지 제한사항이 있다고 한다. 실제 유저의 행동 sequence가 꼭 한 방향으로만 이루어진다고 장담할 수도 없고 hidden representation이 효과적으로 이루어지지 않을 수도 있다.
그렇기에 이 논문은 bidirectional self attention과 Cloze obj를 이용해 더 효과적으로 학습하는 방법을 제시한다.

Cloze task

Bidirectinal한 학습을 위해서 일반적으로 shift를 해서 마지막 이이템에 대한 예측을 하는 목적 함수와 달리 Cloze task는 random하게 마스킹을 하고 주변의 seq를 이용해서 mask 된 부분을 예측한다. 이렇게 하면 양방향적인 학습 뿐만 아니라 데이터 양 또한 늘어나는 장점이 있다.
하단 코드를 보면 랜덤하게 수를 추출하고 0.9 이상이면 마스킹을 한다. 마스킹을 할 때는 아이템 ID + 1을 해서 이를 마스킹 인덱스로 활용하고 있다.

class SeqDataset(Dataset):
   def __init__(self, user_train, num_user, num_item, max_len, mask_prob):
       self.user_train = user_train
       self.num_user = num_user
       self.num_item = num_item
       self.max_len = max_len
       self.mask_prob = mask_prob
def __len__(self):
    # 총 user의 수 = 학습에 사용할 sequence의 수
    return self.num_user

def __getitem__(self, user): 

    seq = self.user_train[user]
    tokens = []
    labels = []
    for s in seq:
        prob = np.random.random() 
        if prob < self.mask_prob:
            prob /= self.mask_prob

            # 랜덤하게 마스킹
            if prob < 0.8:
                # masking
                tokens.append(self.num_item + 1)  # mask_index: num_item + 1, 0: pad, 1~num_item: item index
            elif prob < 0.9:
                tokens.append(np.random.randint(1, self.num_item+1))  # item random sampling
            else:
                tokens.append(s)
            labels.append(s)  # 학습에 사용
        else:
            tokens.append(s)
            labels.append(0)  # 학습에 사용 X, trivial
    tokens = tokens[-self.max_len:]
    labels = labels[-self.max_len:]
    mask_len = self.max_len - len(tokens)

    # zero padding
    tokens = [0] * mask_len + tokens
    labels = [0] * mask_len + labels
    return torch.LongTensor(tokens), torch.LongTensor(labels)
    

Transformer Block

SASRec과 다르게 transformer block들도 양방향으로 전이가 된다. 또한 이를 여러 층 쌓음으로서 더욱 효과적으로 학습할 수 있도록 한다.

#Output

마지막에는 Projection layer와 Gelu를 거친 후에 Item Embedding과 행렬 곱 후 softmax를 취해서 각 아이템에 대한 확률 분포를 취한다.

학습

[v1, v2, v3, v4, v5] 가 들어오면 [v1, [mask], v3, [mask], v5] 이렇게 마스킹을 한다.
앞에서 데이터로더에서 0으로 처리한 부분은 CrossEntropyLoss 구할 때 무시가 되므로 마스킹된 부분에 대해서 loss를 구하고 업데이트를 하게 된다.

profile
AI Engineer

0개의 댓글