이진 분류 시: 마지막 노드 1개, BCEWithLogitsLoss 사용
다중 분류 시: 마지막 노드 수 = 클래스 개수, CrossEntropyLoss 사용
pred, label 자료형과 shape은 위의 사진 참고
# Random train/validation split over the indices of `train_dataset`:
# the first `valid_size` fraction of the (optionally shuffled) index list
# goes to validation, the remainder to training.
num_train = len(train_dataset)
split = int(np.floor(valid_size * num_train))
indices = list(range(num_train))

if shuffle:
    # Seed before shuffling so the split is reproducible across runs.
    np.random.seed(random_seed)
    np.random.shuffle(indices)

valid_idx = indices[:split]
train_idx = indices[split:]

# Each sampler draws (in random order) only from its own index subset.
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# Settings shared by both loaders.
loader_kwargs = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "pin_memory": pin_memory,
}
train_loader = torch.utils.data.DataLoader(
    train_dataset, sampler=train_sampler, **loader_kwargs
)
# NOTE(review): validation indices are computed from len(train_dataset) but
# applied to valid_dataset — presumably both wrap the same underlying data
# (e.g. with different transforms); confirm against the caller.
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, sampler=valid_sampler, **loader_kwargs
)
sklearn의 train_test_split을 사용하면 stratify 인자로 층화(stratified) 분할을 할 수 있음
from sklearn.model_selection import train_test_split
# Stratified train/validation split: train_test_split partitions the index
# range of the dataset, using `stratify` so that both parts keep the label
# proportions of the full dataset.
dataset = TotalDataset
dataset_size = len(dataset)
# assumes dataset.labels holds one label per sample, in index order — TODO confirm
train_idx, val_idx = train_test_split(
    np.arange(dataset_size),
    test_size=validation_split,
    shuffle=True,
    random_state=random_seed,
    stratify=dataset.labels,
)

# Each sampler draws (in random order) only from its own index subset.
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(val_idx)

# pin_memory is a boolean flag. The original passed the integer 4, which is
# merely truthy; True states the same behavior explicitly.
# NOTE(review): 4 looks like it may have been intended for num_workers —
# confirm and add num_workers=4 here if so.
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    pin_memory=True,
    sampler=train_sampler,
)
validation_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    pin_memory=True,
    sampler=valid_sampler,
)