250123 TIL #597 AI Tech #129 P:LightGCN 구현 시도

김춘복·2025년 1월 23일
0

TIL : Today I Learned

목록 보기
599/604

Today I Learned

오늘은 LightGCN 모델 구현! 아직 완성은 못 하고 구현 중!


LightGCN 구현

모델

  • class LightGCN
class LightGCN(nn.Module):
    """LightGCN (He et al., SIGIR 2020) collaborative-filtering model.

    Propagates user/item embeddings over a normalized user-item graph with
    no feature transformation and no non-linearity; the final embedding is
    the mean of the per-layer embeddings (including layer 0).

    Args:
        n_users: number of users (rows of the user embedding table).
        n_items: number of items (rows of the item embedding table).
        embedding_dim: size of each embedding vector.
        n_layers: number of propagation layers.
        dropout: dropout rate applied to propagated embeddings during
            training (0 disables it).
        device: torch device string for embeddings and inputs.
    """

    def __init__(self, n_users, n_items, embedding_dim=16, n_layers=2, dropout=0.1, device='cuda'):
        super(LightGCN, self).__init__()
        self.n_users = n_users
        self.n_items = n_items
        self.embedding_dim = embedding_dim
        self.n_layers = n_layers
        self.dropout = dropout
        self.device = device

        self.user_embedding = nn.Embedding(n_users, embedding_dim).to(self.device)
        self.item_embedding = nn.Embedding(n_items, embedding_dim).to(self.device)

        # Xavier init keeps initial dot-product scores in a reasonable range.
        nn.init.xavier_uniform_(self.user_embedding.weight)
        nn.init.xavier_uniform_(self.item_embedding.weight)

    def get_ego_embeddings(self):
        """Return layer-0 embeddings: users stacked above items, shape
        (n_users + n_items, embedding_dim)."""
        user_embeddings = self.user_embedding.weight
        item_embeddings = self.item_embedding.weight
        ego_embeddings = torch.cat([user_embeddings, item_embeddings], dim=0)
        return ego_embeddings

    def forward(self, adj_mat):
        """Run n_layers of propagation and average the layer outputs.

        Args:
            adj_mat: sparse (n_users+n_items)^2 normalized adjacency matrix.

        Returns:
            (user_embeddings, item_embeddings) after layer averaging.
        """
        adj_mat = adj_mat.to(self.device)
        ego_embeddings = self.get_ego_embeddings()
        all_embeddings = [ego_embeddings]

        for k in range(self.n_layers):
            # Pure neighborhood aggregation: E^(k+1) = A_norm @ E^(k).
            ego_embeddings = torch.sparse.mm(adj_mat, ego_embeddings)
            if self.training and self.dropout != 0:
                ego_embeddings = F.dropout(ego_embeddings, p=self.dropout, training=self.training)
            all_embeddings.append(ego_embeddings)

        # Mean over layers (uniform layer-combination weights).
        all_embeddings = torch.stack(all_embeddings, dim=1)
        all_embeddings = torch.mean(all_embeddings, dim=1)

        user_all_embeddings, item_all_embeddings = torch.split(all_embeddings, [self.n_users, self.n_items])
        return user_all_embeddings, item_all_embeddings

    def calculate_loss(self, user_embeddings, item_embeddings, users, pos_items, neg_items):
        """BPR loss over (user, positive item, negative item) triplets.

        Fix: uses F.logsigmoid(x) instead of log(sigmoid(x)) — the latter
        underflows to -inf when pos - neg is a large negative value, which
        produces NaN gradients.
        """
        users = users.to(self.device)
        pos_items = pos_items.to(self.device)
        neg_items = neg_items.to(self.device)

        user_embeddings = user_embeddings[users]
        pos_scores = torch.sum(torch.mul(user_embeddings, item_embeddings[pos_items]), dim=1)
        neg_scores = torch.sum(torch.mul(user_embeddings, item_embeddings[neg_items]), dim=1)

        # Maximize the margin between positive and negative scores.
        loss = -torch.mean(F.logsigmoid(pos_scores - neg_scores))
        return loss

    def predict(self, user_embeddings, item_embeddings, users, k=10,
                user_encoded2user=None, anime_encoded2anime=None):
        """Score all items for the given users and return the top-k per user
        as a DataFrame with columns (user_id, anime_id, rank).

        Args:
            user_embeddings / item_embeddings: outputs of forward().
            users: tensor of encoded user indices to predict for.
            k: number of items to recommend per user.
            user_encoded2user / anime_encoded2anime: index -> raw-id maps.

        Raises:
            ValueError: if either id-mapping dict is missing.
        """
        if user_encoded2user is None or anime_encoded2anime is None:
            raise ValueError("user_encoded2user and anime_encoded2anime mappings must be provided")

        users = users.to(self.device)
        user_embeddings = user_embeddings[users]

        # Score users against all items in chunks to bound peak memory.
        batch_size = 256
        n_users = len(users)
        topk_items = []

        with torch.no_grad():
            for start in tqdm(range(0, n_users, batch_size), desc="Generating predictions"):
                end = min(start + batch_size, n_users)
                batch_users = user_embeddings[start:end]

                scores = torch.matmul(batch_users, item_embeddings.t())
                _, top_items = torch.topk(scores, k=k, dim=1)
                topk_items.append(top_items)

        topk_items = torch.cat(topk_items, dim=0)

        # Decode indices back to raw ids, one row per (user, rank).
        results = []
        for idx, user_idx in tqdm(enumerate(users.cpu()), total=len(users), desc="Converting to DataFrame"):
            user_id = user_encoded2user[user_idx.item()]
            for rank, item_idx in enumerate(topk_items[idx].cpu()):
                results.append({
                    'user_id': user_id,
                    'anime_id': anime_encoded2anime[item_idx.item()],
                    'rank': rank + 1
                })

        return pd.DataFrame(results)
  • dataset
class LightGCNDataset(Dataset):
    """BPR triplet dataset yielding (user, positive item, negative item).

    Bug fix vs. the original: negatives were sampled as a fixed 50% fraction
    of the negative pool, so ``neg_items`` could end up shorter than
    ``users`` and ``__getitem__`` raised IndexError for the tail indices.
    We now sample exactly one negative per positive pair (with replacement,
    so a small negative pool still works).

    Args:
        train_data: DataFrame with 'user', 'item', 'interaction' columns,
            where interaction == 1 marks a positive pair and 0 a negative.
        n_items: total item count (kept for interface compatibility;
            unused here).
    """

    def __init__(self, train_data, n_items):
        # Positive interactions define the dataset length.
        pos_data = train_data[train_data['interaction'] == 1]
        self.users = pos_data['user'].values
        self.pos_items = pos_data['item'].values

        # One negative per positive; replace=True guards against a negative
        # pool smaller than the positive set.
        neg_data = train_data[train_data['interaction'] == 0]
        neg_data_sampled = neg_data.sample(n=len(self.users), replace=True, random_state=42)
        self.neg_items = neg_data_sampled['item'].values

    def __len__(self):
        return len(self.users)

    def __getitem__(self, idx):
        # Arrays are index-aligned by construction.
        return self.users[idx], self.pos_items[idx], self.neg_items[idx]
  • train&evaluate 함수
def build_adj_matrix(train_data, n_users, n_items):
    """Build the symmetrically normalized adjacency D^(-1/2) A D^(-1/2)
    over the joint user+item node set, as a torch sparse COO tensor.

    Args:
        train_data: DataFrame with encoded 'user' and 'item' columns.
        n_users: number of users; item nodes are offset by this amount.
        n_items: number of items.

    Returns:
        Sparse float32 tensor of shape (n_users+n_items, n_users+n_items).
    """
    user_indices = train_data['user'].values
    item_indices = train_data['item'].values + n_users  # item nodes come after user nodes

    # Fix: use a numpy array (the original passed a torch tensor to scipy).
    interactions = np.ones(len(user_indices), dtype=np.float32)

    # Sparse bipartite interaction matrix over the joint node set.
    adj_mat = sp.coo_matrix(
        (interactions, (user_indices, item_indices)),
        shape=(n_users + n_items, n_users + n_items),
        dtype=np.float32
    )

    # Symmetrize (undirected graph). User and item index ranges are
    # disjoint, so A + A.T is exactly the mirrored edge set.
    adj_mat = (adj_mat + adj_mat.T).tocsr()

    # Symmetric normalization D^(-1/2) A D^(-1/2); isolated nodes
    # (degree 0) get weight 0 instead of inf.
    row_sum = np.asarray(adj_mat.sum(1)).flatten()
    with np.errstate(divide='ignore'):
        d_inv_sqrt = np.power(row_sum, -0.5)
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    norm_adj_mat = (d_mat_inv_sqrt @ adj_mat @ d_mat_inv_sqrt).tocoo()

    # Extract COO indices/values for torch.
    indices = torch.LongTensor(np.vstack([norm_adj_mat.row, norm_adj_mat.col]))
    values = torch.FloatTensor(norm_adj_mat.data)

    # Fix: torch.sparse.FloatTensor is deprecated; use the supported API.
    return torch.sparse_coo_tensor(
        indices,
        values.to(torch.float32),
        torch.Size(norm_adj_mat.shape)
    )


# 모델 학습
def train_model(train_data, adj_mat, n_users, n_items, epochs=10, learning_rate=0.001, batch_size=256):
    """Train a LightGCN model with BPR triplets from ``train_data``.

    Note: relies on module-level globals ``device`` (torch device) and
    ``save_path`` (checkpoint file path), matching the surrounding script.

    Args:
        train_data: DataFrame consumed by LightGCNDataset.
        adj_mat: normalized sparse adjacency passed to the model's forward.
        n_users / n_items: embedding table sizes.
        epochs, learning_rate, batch_size: optimization hyperparameters.

    Returns:
        The trained LightGCN model.
    """
    loader = DataLoader(LightGCNDataset(train_data, n_items),
                        batch_size=batch_size, shuffle=True)

    model = LightGCN(n_users, n_items).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in tqdm(range(epochs), desc="Epoch Progress"):
        model.train()
        epoch_loss = 0

        batch_iter = tqdm(loader, desc=f"Training Epoch {epoch+1}/{epochs}", leave=False)
        for users, pos_items, neg_items in batch_iter:
            # Full propagation each step: gradients must flow through the
            # graph convolution, so embeddings are recomputed per batch.
            user_emb, item_emb = model(adj_mat)
            loss = model.calculate_loss(user_emb, item_emb, users, pos_items, neg_items)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}')
        # Checkpoint every epoch, overwriting the previous file.
        torch.save(model.state_dict(), save_path)
        print(f'Model saved at epoch {epoch+1} to {save_path}')

    print("Training completed.")
    return model

# 예측 및 CSV 저장
def evaluate_and_save(model, adj_mat, user_encoded2user, anime_encoded2anime, top_k=10):
    """Run full-graph inference and return top-k recommendations per user.

    Args:
        model: trained LightGCN instance.
        adj_mat: adjacency as a sparse tensor (or a dense DataFrame, which
            is converted to a tensor first).
        user_encoded2user / anime_encoded2anime: index -> raw-id maps.
        top_k: number of recommendations per user.

    Returns:
        DataFrame produced by model.predict (user_id, anime_id, rank).
    """
    model.eval()
    # Accept a dense DataFrame adjacency as well as a tensor.
    if isinstance(adj_mat, pd.DataFrame):
        adj_mat = torch.from_numpy(adj_mat.values)
    adj_mat = adj_mat.to(model.device)

    # Fix: inference only — no_grad avoids building the autograd graph
    # for the forward pass (saves memory; results are unchanged).
    with torch.no_grad():
        user_emb, item_emb = model(adj_mat)
    users = torch.arange(model.n_users)

    predictions = model.predict(user_emb, item_emb, users, top_k, user_encoded2user, anime_encoded2anime)
    return predictions
  • id 인코딩
# --- ID encoding ----------------------------------------------------------
# Map raw user/anime ids to dense 0..n-1 indices (and back) so they can be
# used directly as embedding-table rows.
user_ids = train["user_id"].unique().tolist()
user2user_encoded = dict(zip(user_ids, range(len(user_ids))))
user_encoded2user = dict(enumerate(user_ids))

anime_ids = train["anime_id"].unique().tolist()
anime2anime_encoded = dict(zip(anime_ids, range(len(anime_ids))))
anime_encoded2anime = dict(enumerate(anime_ids))

# Replace raw ids with their encoded indices.
train["user"] = train["user_id"].map(user2user_encoded)
train["item"] = train["anime_id"].map(anime2anime_encoded)

n_users = len(user2user_encoded)
n_items = len(anime2anime_encoded)
profile
Backend Dev / Data Engineer

0개의 댓글