torch.utils.data.Subset을 사용하면 감싸진 원본 데이터셋의 __len__이 호출되지 않습니다. Subset의 길이는 원본 데이터셋이 아니라 넘겨준 indices의 길이로 계산되기 때문입니다.
이 사실은 데이터셋의 __len__ 구현에 버그가 있었는데도 Subset으로 감싼 동안에는 문제없이 돌아가다가, Subset을 빼자마자 오류가 나면서 알게 되었습니다. 저도 알고 싶지 않았어요.
관련 실험 스크립트
"""
python scripts/test_subset_and_len.py \
2>&1 | tee output/test_subset_and_len.log
"""
import torch
class ExampleDataset(torch.utils.data.Dataset):
    """Toy map-style dataset holding the integers 0 through 9.

    ``__len__`` and ``__getitem__`` each print a debug line when invoked, so
    a run log shows which of the two methods a caller actually reaches.
    """

    def __init__(self):
        # Ten consecutive integers: the value stored at position i is i.
        self.data = list(range(10))

    def __len__(self):
        """Return the number of stored items, announcing the call on stdout."""
        print("[DEBUG] __len__() is called")
        return len(self.data)

    def __getitem__(self, index):
        """Return the item at ``index``, announcing the call on stdout."""
        print("[DEBUG] __getitem__() is called")
        return self.data[index]
def _run_experiment(label, dataset):
    """Iterate ``dataset`` through a DataLoader and print every batch.

    Args:
        label: Text interpolated into each log line to tell the runs apart
            (the script uses ``"no subset"`` and ``"subset"``).
        dataset: Any dataset object accepted by ``torch.utils.data.DataLoader``.
    """
    print(f"[DEBUG] {label} experiment")
    # Fixed batch size and no shuffling keep the printed order reproducible.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=False)
    for i, batch in enumerate(dataloader):
        print(f"[DEBUG] {label} i: {i}, batch: {batch}")


if __name__ == '__main__':
    dataset = ExampleDataset()

    # Baseline: the DataLoader consumes the dataset directly.
    _run_experiment("no subset", dataset)

    print("=" * 40)

    # Same dataset wrapped in a Subset restricted to indices 3..6; comparing
    # the two logs shows whether the dataset's own __len__ is still reached.
    _run_experiment("subset", torch.utils.data.Subset(dataset, indices=range(3, 7)))
결과
[DEBUG] no subset experiment
[DEBUG] __len__() is called
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 0, batch: tensor([0, 1])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 1, batch: tensor([2, 3])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 2, batch: tensor([4, 5])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 3, batch: tensor([6, 7])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 4, batch: tensor([8, 9])
========================================
[DEBUG] subset experiment
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] subset i: 0, batch: tensor([3, 4])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] subset i: 1, batch: tensor([5, 6])