PyTorch - 2. Datasets & Dataloader

최창우 · December 5, 2022

📜 Understanding PyTorch Datasets & Dataloader

GitHub code

To keep data handling convenient, the data-processing code should be kept separate from the model code.
PyTorch provides primitives (basic building blocks) for working with both pre-built datasets and data you prepare yourself.

  1. torch.utils.data.Dataset
  • A Dataset stores the samples and their corresponding labels.
  • Subclass it to define your own dataset class.
  2. torch.utils.data.DataLoader
  • A DataLoader wraps an iterable around a Dataset so the samples are easy to access.
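The relationship between the two primitives can be sketched with a toy example. `SquaresDataset` is a hypothetical class invented for illustration (it is not part of PyTorch); it only assumes that a map-style Dataset needs `__len__` and `__getitem__`.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the number i, its label is i squared."""

    def __init__(self, n):
        self.x = torch.arange(n, dtype=torch.float32)
        self.y = self.x ** 2

    def __len__(self):
        # Number of samples in the dataset
        return len(self.x)

    def __getitem__(self, idx):
        # Return one (sample, label) pair
        return self.x[idx], self.y[idx]


dataset = SquaresDataset(10)
sample, label = dataset[3]  # a Dataset is indexed one sample at a time

loader = DataLoader(dataset, batch_size=4, shuffle=False)
batch_x, batch_y = next(iter(loader))  # a DataLoader yields whole batches

print(len(dataset), sample.item(), label.item())
print(batch_x.shape, batch_y.tolist())
```

The Dataset answers "what is sample i?", while the DataLoader decides how samples are grouped and ordered.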

📕 Functions to know

📖 torchvision.datasets

Used to load pre-built datasets.

from torchvision import datasets
import torchvision.transforms as transforms

training_data = datasets.FashionMNIST(
    root="data", # directory where the train and test data are stored
    train=True, # True for the training split, False for the test split
    download=True, # download from the internet if the data is not found under root
    transform=transforms.ToTensor() # preprocessing applied to the features (labels use target_transform)
)

📖 torchvision.transforms

A collection of preprocessing modules for (image) data.

  1. Compose
  • Chains several preprocessing steps into a single transform.
  2. ToTensor()
  • Converts a PIL image or NumPy array to a Tensor.
  • "H x W x C" → "C x H x W"
  • Values are also scaled to the range 0~1 (but not always).
  • For a NumPy array, scaling happens only when dtype = np.uint8.
  • See the documentation for details.
  3. Normalize(mean, std, inplace=False)
  • Normalizes each channel (mean, standard deviation, whether to modify in place).
  4. ToPILImage()
  • Converts the data to a PIL image, e.g. when the dataset was read from a csv file.
  5. Resize((300, 300))
  • Resizes the image.
  6. RandomHorizontalFlip(p=probability)
  • Flips the image horizontally (p defaults to 0.5).

There are too many transforms to list here; see the documentation.

📖 torch.utils.data.DataLoader

Takes a dataset object as input and lets you load its data as an iterable.

# 1. Using a pre-built dataset class
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

train_dataset = datasets.FashionMNIST(
    root="data",
    train=True, 
    download=True,
    transform=transforms.ToTensor()
)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# 2. Using a user-defined dataset class
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

class CustomDataset(Dataset):
    '''
    Placeholder: the full class is defined in the user-defined dataset section.
    '''
    ...

train_path = "dataset/fashion-mnist_train.csv"
train_dataset = CustomDataset(train_path,
                              transform=transforms.Compose([transforms.ToTensor()]))
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
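To see what batch_size=64 does in practice, the sketch below runs a DataLoader over 100 fake samples. The tensors are invented stand-ins for real images, and `TensorDataset` is used only to avoid downloading anything.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 100 fake grayscale "images" (1 x 28 x 28) with integer labels 0..9
features = torch.zeros(100, 1, 28, 28)
targets = torch.arange(100) % 10
dataset = TensorDataset(features, targets)

loader = DataLoader(dataset, batch_size=64, shuffle=True)
batch_sizes = [x.shape[0] for x, y in loader]

# drop_last=True discards the final, incomplete batch
loader_drop = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

print(len(loader), batch_sizes)  # 2 batches: 64 samples, then the remaining 36
print(len(loader_drop))          # 1
```

The last batch is smaller whenever the dataset size is not a multiple of batch_size, which matters for code that assumes a fixed batch dimension.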

📖 torch.randint(high, size=(a, b))

  • Returns a tensor of shape (a, b) filled with random integers in the range [0, high).
  • .item() extracts the value of a single-element tensor as a Python number.

import torch

x = torch.randint(len(training_data),size=(1,))

print(x)
print(x.item())

> tensor([12491])
> 12491
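A slightly fuller sketch of the same call, which does not need the dataset to be loaded. `dataset_size` is an assumed stand-in for len(training_data), and the seed is fixed only so that repeated runs draw the same values.

```python
import torch

torch.manual_seed(0)   # fix the seed so repeated runs draw the same values
dataset_size = 60000   # assumed stand-in for len(training_data)

# One random index in [0, dataset_size), as a tensor of shape (1,)
idx = torch.randint(dataset_size, size=(1,))
print(idx.shape)
print(type(idx.item()))  # .item() returns a plain Python int

# A 2 x 3 tensor of random integers in [0, 10)
grid = torch.randint(10, size=(2, 3))
print(grid.shape)
```

The upper bound is exclusive, so the result can always be used directly as an index into the dataset.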

📕 Using a pre-built dataset

1. Create the dataset object
2. Create the DataLoader object

from torchvision import datasets
import torchvision.transforms as transforms

# 1. Create the dataset objects
training_data = datasets.FashionMNIST(
    root="data", # directory where the train and test data are stored
    train=True, # True for the training split, False for the test split
    download=True, # download from the internet if the data is not found under root
    transform=transforms.ToTensor() # preprocessing applied to the features (labels use target_transform)
)
test_data = datasets.FashionMNIST(
    root="data", # directory where the train and test data are stored
    train=False, # True for the training split, False for the test split
    download=True, # download from the internet if the data is not found under root
    transform=transforms.ToTensor() # preprocessing applied to the features (labels use target_transform)
)

# 2. Create the DataLoader objects
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

# 3. The DataLoader is iterable, so fetch one batch and plot a sample from it
import matplotlib.pyplot as plt

images, labels = next(iter(train_dataloader))
idx = 0
img = images[idx].squeeze()
label = labels[idx]
plt.figure(figsize=[2,2])
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
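The printed label is a tensor holding a class index. As a small optional sketch, the index can be turned into a readable name with a lookup table; `labels_map` is a helper name introduced here for illustration, and the tensor below is a stand-in for a label taken from a batch.

```python
import torch

# FashionMNIST class names, keyed by label index
labels_map = {
    0: "T-shirt/top",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle boot",
}

label = torch.tensor(9)  # stand-in for one label taken from a batch
name = labels_map[label.item()]  # .item() converts the tensor to a Python int
print(f"Label: {label.item()} ({name})")  # Label: 9 (Ankle boot)
```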

📕 Using a user-defined dataset

1. Define the dataset class
2. Create the dataset object
3. Create the DataLoader object

import numpy as np
import pandas as pd
from torch.utils.data import Dataset

# 1. Define the dataset class
class CustomDataset(Dataset):
    
    # Runs once, when the object is created
    # Set up the data path, transforms, etc.
    def __init__(self, data_path, transform=None, target_transform=None):
        
        data = pd.read_csv(data_path)
        Y_data = data['label']
        Y_data = np.array(Y_data)
        X_data = data.drop(columns='label')
        # H x W x C layout; float32 arrays are not rescaled to 0~1 by ToTensor
        X_data = np.array(X_data).reshape(-1,28,28,1).astype('float32')
        self.X_data = X_data
        self.Y_data = Y_data
        self.transform = transform
        self.target_transform = target_transform

    # Returns the total number of samples in the dataset
    def __len__(self):
        return len(self.Y_data)

    # Fetches one sample from the dataset by index
    # Returns a tuple: (input, output)
    def __getitem__(self, idx):            

        # Select the data
        image = self.X_data[idx]
        label = self.Y_data[idx]
        
        # Preprocess the input
        if self.transform:
            image = self.transform(image)

        # Preprocess the output
        if self.target_transform:
            label = self.target_transform(label)

        return image, label


# 2. Create the dataset objects
import torchvision.transforms as transforms

train_path = "dataset/fashion-mnist_train.csv"
test_path = "dataset/fashion-mnist_test.csv"
train_dataset = CustomDataset(train_path,
                              transform=transforms.Compose([transforms.ToTensor()]))
test_dataset = CustomDataset(test_path,
                             transform=transforms.Compose([transforms.ToTensor()]))
                             
# 3. Create the DataLoader objects
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

# 4. The DataLoader is iterable, so fetch one batch and plot a sample from it
import matplotlib.pyplot as plt

images, labels = next(iter(train_loader))
idx = 0
img = images[idx].squeeze()
label = labels[idx]
plt.figure(figsize=[2,2])
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
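The steps above depend on csv files that may not be at hand, so here is a self-contained sketch of the same idea that can be run anywhere. Everything in it is an assumption made for illustration: `CsvImageDataset`, the `pixel0..pixel783` column names, and the randomly generated csv file. It also converts to tensors by hand instead of using torchvision, which is a swapped-in simplification rather than the approach used in this post.

```python
import os
import tempfile

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader


class CsvImageDataset(Dataset):
    """Hypothetical csv-backed dataset: one 'label' column plus 784 pixel columns."""

    def __init__(self, csv_path):
        data = pd.read_csv(csv_path)
        self.labels = data["label"].to_numpy()
        pixels = data.drop(columns="label").to_numpy(dtype="float32")
        # Reshape straight to C x H x W and scale to 0~1 by hand
        self.images = pixels.reshape(-1, 1, 28, 28) / 255.0

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return torch.from_numpy(self.images[idx]), int(self.labels[idx])


# Write a tiny random csv so the example does not depend on any real file
rng = np.random.default_rng(0)
frame = pd.DataFrame(rng.integers(0, 256, size=(10, 784)),
                     columns=[f"pixel{i}" for i in range(784)])
frame.insert(0, "label", rng.integers(0, 10, size=10))

with tempfile.TemporaryDirectory() as tmp:
    csv_path = os.path.join(tmp, "fake_fashion.csv")
    frame.to_csv(csv_path, index=False)
    dataset = CsvImageDataset(csv_path)  # data is read fully into memory here

loader = DataLoader(dataset, batch_size=4, shuffle=False)
images, labels = next(iter(loader))
print(len(dataset), images.shape, labels.shape)
```

Checking the batch shape like this before training is a cheap way to catch layout mistakes such as H x W x C versus C x H x W.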