about shuffle of DataLoader..

d4r6j·2023년 9월 25일

error log

목록 보기
1/3

Test code

import torchvision.transforms as transforms
from torchvision import datasets

from tqdm import tqdm
from torch.utils.data import DataLoader

# Experiment: check whether DataLoader(shuffle=True) reshuffles every epoch.
# For each epoch the labels of the first batch are printed so the orders can
# be compared between epochs (and between runs, e.g. with torch.manual_seed).
DATA_PATH = "/data/images/stl10"

# Compose applies its transforms in sequence order, so it must receive an
# ordered sequence. A set literal ({...}) has no guaranteed iteration order,
# which could make Resize run before ToTensor.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias=None),
])

# Attach the transform at construction time instead of patching the
# attribute afterwards.
train_dataset = datasets.STL10(DATA_PATH,
                               split='train',
                               download=True,
                               transform=transform)
# Kept under the original name so code using `data_train` still sees the loader.
data_train = DataLoader(train_dataset, batch_size=16, shuffle=True)

tr_step = 0
for epoch in range(2):
    # Build a fresh progress bar per epoch: a tqdm instance closes (and
    # disables) itself when its iteration finishes, so re-iterating one shared
    # instance would still yield batches but show no bar after epoch 0.
    train_progress = tqdm(iterable=data_train,
                          bar_format="{l_bar}{bar:25}{r_bar}",
                          colour="green",
                          total=len(data_train),
                          desc=f"epoch {epoch}",
                          leave=True)
    for batch_idx, (tr_data, tr_target) in enumerate(train_progress):
        tr_step += 1
        data = tr_data.to("cpu")
        target = tr_target.to("cpu")
        if batch_idx == 0:
            # tqdm.write prints without corrupting the active progress bar.
            tqdm.write(f"{epoch} {tr_target}")
  • shuffle=True

    epoch | tr_target
    0     | tensor([6, 8, 8, 6, 5, 5, 5, 9, 4, 7, 4, 1, 8, 4, 4, 0])
    1     | tensor([4, 4, 1, 3, 9, 5, 6, 8, 5, 1, 9, 3, 0, 9, 9, 4])
  • shuffle=False

    epoch | tr_target
    0     | tensor([1, 5, 1, 6, 3, 9, 7, 4, 5, 8, 0, 6, 0, 8, 7, 6])
    1     | tensor([1, 5, 1, 6, 3, 9, 7, 4, 5, 8, 0, 6, 0, 8, 7, 6])

epoch 루프 바깥(윗단)에서 DataLoader 를 한 번만 생성하므로, shuffle 설정에 따라 한 번 섞인 순서로 메모리에 load 된 뒤 매 epoch 같은 순서로 반복 사용되는 줄 알았다. 그런데 datapipe 라는 것이 있고,

# torch/utils/data/graph_settings.py

def apply_shuffle_settings(datapipe: DataPipe, shuffle: Optional[bool] = None) -> DataPipe:
    r"""
    Traverse the graph of ``DataPipes`` to find and set shuffle attribute
    to each `DataPipe` that has APIs of ``set_shuffle`` and ``set_seed``.

    Args:
        datapipe: DataPipe that needs to set shuffle attribute
        shuffle: Shuffle option (default: ``None`` and no-op to the graph)

    Returns:
        The datapipe to use from now on. When ``shuffle`` is truthy and the
        graph contains no shuffler, this is the result of ``datapipe.shuffle()``
        rather than the input object, so callers must use the return value.
    """
    # None means "leave the graph untouched".
    if shuffle is None:
        return datapipe

    # Collect every DataPipe in the graph and keep only the shuffle-capable ones.
    graph = traverse_dps(datapipe)
    all_pipes = get_all_graph_pipes(graph)
    shufflers = [pipe for pipe in all_pipes if _is_shuffle_datapipe(pipe)]
    # Shuffling was requested but nothing in the graph can shuffle: append a
    # shuffler at the end of the pipeline and treat it as the only one.
    # (With shuffle=False and no shufflers, nothing needs to be disabled.)
    if not shufflers and shuffle:
        warnings.warn(
            "`shuffle=True` was set, but the datapipe does not contain a `Shuffler`. Adding one at the end. "
            "Be aware that the default buffer size might not be sufficient for your task."
        )
        datapipe = datapipe.shuffle()
        shufflers = [datapipe, ]  # type: ignore[list-item]

    # Enable or disable shuffling on every shuffler found (or added) above.
    for shuffler in shufflers:
        shuffler.set_shuffle(shuffle)

    return datapipe

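DataLoader 에 shuffle=True 를 주었을 때, dataset 이 IterDataPipe 라면 위의 apply_shuffle_settings 를 통해 datapipe 그래프 안의 Shuffler 에 shuffle 설정이 전달되고, epoch 마다 새 iterator 를 만들 때 다시 셔플된다. 다만 위 테스트 코드의 STL10 처럼 map-style dataset 인 경우에는 datapipe 를 거치지 않고, DataLoader 가 내부적으로 RandomSampler 를 사용한다. for 문이 DataLoader 를 다시 순회할 때마다(즉 매 epoch) RandomSampler.__iter__ 가 새로 호출되어 torch.randperm 으로 새 순열을 만든다. 따라서 이 경우에도 epoch 마다 순서가 바뀐다.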

  • torch.manual_seed(3355)

    epoch | tr_target
    0     | tensor([6, 4, 8, 8, 7, 7, 6, 1, 7, 3, 9, 8, 6, 3, 4, 1])
    1     | tensor([8, 8, 2, 9, 8, 6, 6, 9, 6, 0, 7, 6, 9, 4, 4, 2])

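    seed 를 고정할 경우에도 epoch 마다 순서는 바뀌지만, 실행할 때마다 같은 순서가 재현된다. RandomSampler 는 generator 가 주어지지 않으면 매 epoch 시작 시 전역 RNG 에서 새 seed 를 하나 뽑고, 그 seed 로 만든 generator 로 한 번 더 shuffle(torch.randperm) 한다. 즉, 고정된 seed 로부터 epoch 마다 새 seed 를 뽑아 다시 shuffle 하는 구조다. torch.manual_seed 로 전역 RNG 를 고정하면 epoch 별로 뽑히는 seed 수열 자체가 고정되므로, epoch 간에는 순서가 다르지만 실행 간에는 동일하다. DataLoader(..., generator=torch.Generator().manual_seed(3355)) 처럼 DataLoader 전용 generator 를 넘겨도 같은 재현성을 얻을 수 있다.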

0개의 댓글