The relationship between `Subset` and `__len__` in torch

양세종 · November 15, 2023

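When a dataset is wrapped in torch.utils.data.Subset, the wrapped dataset's __len__ is not called. Subset reports its own length as the number of indices it was given, so nothing ever asks the original dataset how long it is.
How did I find out? My __len__ had a bug, everything ran fine anyway, and then it blew up the moment I removed Subset. I didn't want to know either.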
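
To make the mechanism concrete, here is a torch-free sketch. `SubsetSketch` is a simplified stand-in I wrote for illustration, not the real torch.utils.data.Subset, and `CountingDataset` is a hypothetical dataset that counts how often its own `__len__` is called. The assumption it models is that the wrapper takes its length from the index list and only forwards item lookups.

```python
class CountingDataset:
    """Hypothetical map-style dataset that counts calls to its own __len__."""

    def __init__(self):
        self.data = list(range(10))
        self.len_calls = 0

    def __len__(self):
        self.len_calls += 1
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]


class SubsetSketch:
    """Simplified stand-in for torch.utils.data.Subset (not the real class)."""

    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        # The length comes from the index list; the wrapped dataset is never asked.
        return len(self.indices)

    def __getitem__(self, idx):
        # Only item lookups are forwarded to the wrapped dataset.
        return self.dataset[self.indices[idx]]


base = CountingDataset()
subset = SubsetSketch(base, range(3, 7))

subset_len = len(subset)
items = [subset[i] for i in range(subset_len)]
calls_via_subset = base.len_calls  # __len__ calls on the base dataset so far

direct_len = len(base)
calls_after_direct = base.len_calls

print(subset_len, items, calls_via_subset, calls_after_direct)
```

In this sketch a broken `__len__` on the base dataset would go unnoticed for as long as every access goes through the wrapper.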

Experiment script

"""
python scripts/test_subset_and_len.py \
2>&1 | tee output/test_subset_and_len.log
"""
import torch


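class ExampleDataset(torch.utils.data.Dataset):
    """Toy map-style dataset that logs every __len__ and __getitem__ call."""

    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        # Logged so the output shows whether anything asks this dataset for its length.
        print("[DEBUG] __len__() is called")
        return len(self.data)

    def __getitem__(self, index):
        print("[DEBUG] __getitem__() is called")
        return self.data[index]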


if __name__ == '__main__':
    # Baseline: the DataLoader iterates over the dataset directly.
    print("[DEBUG] no subset experiment")
    dataset = ExampleDataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False)
    for i, batch in enumerate(dataloader):
        print(f"[DEBUG] no subset i: {i}, batch: {batch}")
    print("=" * 40)
    # Same dataset, but wrapped in a Subset that selects indices 3 to 6.
    print("[DEBUG] subset experiment")
    dataset_subset = torch.utils.data.Subset(dataset, indices=range(3, 7))
    dataloader_subset = torch.utils.data.DataLoader(dataset_subset, batch_size=2, shuffle=False)
    for i, batch in enumerate(dataloader_subset):
        print(f"[DEBUG] subset i: {i}, batch: {batch}")

Output

[DEBUG] no subset experiment
[DEBUG] __len__() is called
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 0, batch: tensor([0, 1])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 1, batch: tensor([2, 3])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 2, batch: tensor([4, 5])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 3, batch: tensor([6, 7])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 4, batch: tensor([8, 9])
========================================
[DEBUG] subset experiment
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] subset i: 0, batch: tensor([3, 4])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] subset i: 1, batch: tensor([5, 6])
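
The situation that led to this post can be reproduced in a few lines. This is only a sketch: `BrokenLenDataset` is a hypothetical dataset whose `__len__` always raises, standing in for whatever the real bug was (the original buggy code is not shown here). It assumes the default `num_workers=0` and `shuffle=False`, as in the script above.

```python
from torch.utils.data import DataLoader, Dataset, Subset


class BrokenLenDataset(Dataset):
    """Hypothetical dataset with a deliberately broken __len__."""

    def __init__(self):
        self.data = list(range(10))

    def __len__(self):
        # Stand-in for a real bug: any call to __len__ fails loudly.
        raise RuntimeError("buggy __len__ was called")

    def __getitem__(self, index):
        return self.data[index]


dataset = BrokenLenDataset()

# Wrapped in Subset: the length comes from the indices, so the bug stays hidden.
subset_loader = DataLoader(Subset(dataset, indices=range(4)), batch_size=2, shuffle=False)
subset_batches = [batch.tolist() for batch in subset_loader]
print(subset_batches)

# Without Subset: the loader needs the dataset's own length, and the bug surfaces.
error_without_subset = None
try:
    for _ in DataLoader(dataset, batch_size=2, shuffle=False):
        pass
except RuntimeError as exc:
    error_without_subset = exc
print(repr(error_without_subset))
```

The takeaway is that a passing run with Subset says nothing about whether the base dataset's `__len__` is correct, so it is worth calling `len(dataset)` on the unwrapped dataset once as a sanity check.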