이번엔 파이토치의 Dataset class를 상속받아 사용자 정의 dataset을 만들 것이다. 이 작업은 꼭 필요한데, 프로젝트마다 도메인이 다르고 데이터의 저장 방법도 다르기 때문에 각 프로젝트에 맞게 데이터셋을 커스텀해야 한다.
주의할 점은 Dataset을 상속받을 때 `__len__(self)`와 `__getitem__(self, index)`는 반드시 재정의해 줘야 한다는 것이다.
import pandas as pd
import torchaudio
from torch.utils.data import Dataset
class SoundDataset(Dataset):
    """
    A dataset that serves (audio signal, label) pairs listed in a CSV annotation table.

    The table is accessed by column position, not by column name:
    column 1 holds the audio file path and column 2 holds the label.

    Attributes:
        annotations (pandas.DataFrame): The annotation table parsed by
            ``pandas.read_csv``, one row per audio file.

    Methods:
        __len__(): Returns the total number of audio files in the dataset.
        __getitem__(index): Returns the audio signal and label for a given index.
        _get_audio_path(index): Returns the path to the audio file at the given index.
        _get_audio_label(index): Returns the label of the audio file at the given index.
    """

    def __init__(self, annotations_files):
        """
        Initializes the SoundDataset by reading the annotation CSV.

        Parameters:
            annotations_files: Anything ``pandas.read_csv`` accepts, e.g. a path to a
                CSV file or a file-like object such as ``io.StringIO``. Each row is
                expected to be laid out as [id, path, label].
        """
        self.annotations = pd.read_csv(annotations_files)

    def __len__(self):
        """
        Returns the total number of audio files in the dataset.

        Returns:
            int: The number of rows in the annotation table.
        """
        return len(self.annotations)

    def __getitem__(self, index):
        """
        Loads the audio file at the given row and returns it with its label.

        Parameters:
            index (int): The positional row index of the audio file.

        Returns:
            tuple: ``(signal, label)`` where ``signal`` is the first value returned by
            ``torchaudio.load`` and ``label`` is the value stored in the label column.
        """
        audio_path = self._get_audio_path(index)
        label = self._get_audio_label(index)
        # torchaudio.load returns (signal, sample_rate); the sample rate is not
        # part of a sample, so it is discarded explicitly.
        signal, _ = torchaudio.load(audio_path)
        return signal, label

    def _get_audio_path(self, index):
        """
        Returns the path to the audio file at the given index.

        Parameters:
            index (int): The positional row index of the audio file.

        Returns:
            The value stored in column 1 (the path column) of that row.
        """
        return self.annotations.iloc[index, 1]

    def _get_audio_label(self, index):
        """
        Returns the label of the audio file at the given index.

        Parameters:
            index (int): The positional row index of the audio file.

        Returns:
            The value stored in column 2 (the label column) of that row.
        """
        return self.annotations.iloc[index, 2]
이번 프로젝트의 데이터 라벨링 정보는 CSV 파일로 관리하고 pandas로 읽어 들였다. column은 index, file_path, label 순서이며, 각 값은 iloc을 이용해 위치(열 번호) 기준으로 가져온다.
import unittest
import pandas as pd
import os
from unittest.mock import patch
import torch
from data.sound_dataset import SoundDataset
class TestSoundDataset(unittest.TestCase):
    """Unit tests for SoundDataset, backed by a small throwaway CSV annotation file."""

    def setUp(self):
        # Build a tiny fake annotation table and write it out as a CSV file.
        self.annotations = pd.DataFrame(
            [
                [0, "fake_path_1.wav", "label_1"],
                [1, "fake_path_2.wav", "label_2"],
            ],
            columns=["id", "path", "label"],
        )
        self.annotations_file = "test_annotations.csv"
        self.annotations.to_csv(self.annotations_file, index=False)
        # Register the removal immediately: cleanups still run when a later line
        # of setUp raises, whereas tearDown would be skipped and leak the file.
        self.addCleanup(os.remove, self.annotations_file)

        # Create the SoundDataset under test.
        self.dataset = SoundDataset(self.annotations_file)

    def test_len(self):
        """The dataset length equals the number of annotation rows."""
        self.assertEqual(len(self.dataset), 2)

    def test_getitem(self):
        """Indexing loads the file named in the path column and returns its label."""
        # Replace torchaudio.load so that no real audio file is needed.
        with patch('torchaudio.load') as mocked_load:
            mocked_load.return_value = (torch.zeros((1, 16000)), 16000)
            signal, label = self.dataset[0]
            mocked_load.assert_called_once_with("fake_path_1.wav")
            self.assertTrue(torch.equal(signal, torch.zeros((1, 16000))))
            self.assertEqual(label, "label_1")

    def test_get_audio_path(self):
        """The path helper returns the value from the path column."""
        path = self.dataset._get_audio_path(0)
        self.assertEqual(path, "fake_path_1.wav")

    def test_get_audio_label(self):
        """The label helper returns the value from the label column."""
        label = self.dataset._get_audio_label(0)
        self.assertEqual(label, "label_1")


if __name__ == '__main__':
    unittest.main()
딥러닝 모델을 학습할 때는 데이터를 training set과 validation set으로 나눠서 검증하는데, 이 분할을 책임질 class를 만들자. 중요한 매개변수는 조정 가능하도록 했다.
from torch.utils.data import Dataset, DataLoader, random_split
class DataSplitter:
    """
    A class to split a dataset into training and validation sets.

    Attributes:
        dataset (Dataset): The dataset to be split.
        train_set (Dataset): The training subset (None until split_dataset is called).
        val_set (Dataset): The validation subset (None until split_dataset is called).

    Methods:
        split_dataset(train_ratio): Splits the dataset into training and validation sets.
        get_loaders(train_batch_size, val_batch_size, num_workers): Returns DataLoaders for the splits.
    """

    def __init__(self, dataset):
        """
        Initializes the DataSplitter with the given dataset.

        Parameters:
            dataset (Dataset): The dataset to be split.
        """
        self.dataset = dataset
        self.train_set = None
        self.val_set = None

    def split_dataset(self, train_ratio=0.8):
        """
        Randomly splits the dataset into training and validation sets.

        The training size is ``int(train_ratio * len(dataset))``; every remaining
        sample goes to the validation set.

        Parameters:
            train_ratio (float): The fraction of samples assigned to the training
                set. Must lie in [0, 1]. Default is 0.8.

        Raises:
            ValueError: If train_ratio is outside [0, 1].
        """
        # A ratio outside [0, 1] would yield a negative split size, which
        # random_split does not reject because the two sizes still sum to len(dataset).
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(
                f"train_ratio must be between 0 and 1, got {train_ratio!r}"
            )
        total = len(self.dataset)
        train_size = int(train_ratio * total)
        val_size = total - train_size
        self.train_set, self.val_set = random_split(self.dataset, [train_size, val_size])

    def get_loaders(self, train_batch_size=32, val_batch_size=32, num_workers=2):
        """
        Returns DataLoaders for the training and validation sets.

        The training loader shuffles its samples; the validation loader does not.

        Parameters:
            train_batch_size (int): Batch size for the training set. Default is 32.
            val_batch_size (int): Batch size for the validation set. Default is 32.
            num_workers (int): Value passed as ``num_workers`` to both DataLoaders.
                Default is 2.

        Returns:
            tuple: A tuple containing the training and validation DataLoaders.

        Raises:
            RuntimeError: If split_dataset() has not been called yet.
        """
        # Fail with a clear message instead of the opaque error DataLoader
        # raises when it is handed None as its dataset.
        if self.train_set is None or self.val_set is None:
            raise RuntimeError("split_dataset() must be called before get_loaders()")
        train_loader = DataLoader(
            self.train_set,
            batch_size=train_batch_size,
            shuffle=True,
            num_workers=num_workers,
        )
        val_loader = DataLoader(
            self.val_set,
            batch_size=val_batch_size,
            shuffle=False,
            num_workers=num_workers,
        )
        return train_loader, val_loader
from data.data_splitter import DataSplitter
from data.sound_dataset import SoundDataset
import unittest
from unittest.mock import patch, MagicMock
from io import StringIO
import torch
class TestDataSplitter(unittest.TestCase):
    """Checks the subset sizes and batch counts that DataSplitter produces."""

    def setUp(self):
        # Ten annotation rows held in memory, so no file is written to disk.
        # The CSV rows stay at column 0 so the literal's content is unchanged.
        csv_rows = """index,path,label
1,path/to/audio1.wav,label1
2,path/to/audio2.wav,label1
3,path/to/audio3.wav,label1
4,path/to/audio4.wav,label2
5,path/to/audio5.wav,label2
6,path/to/audio6.wav,label2
7,path/to/audio7.wav,label3
8,path/to/audio8.wav,label3
9,path/to/audio9.wav,label3
10,path/to/audio10.wav,label3"""
        self.annotations_file = StringIO(csv_rows)
        self.dataset = SoundDataset(self.annotations_file)

    def _split_80_20(self):
        """Return a DataSplitter over the fixture dataset, already split 80/20."""
        splitter = DataSplitter(self.dataset)
        splitter.split_dataset(train_ratio=0.8)
        return splitter

    @patch('torchaudio.load')
    def test_split_dataset(self, mock_load):
        """An 80/20 split of ten rows yields eight training and two validation samples."""
        # Set the value returned by the mocked torchaudio.load.
        mock_load.return_value = (torch.zeros(1, 16000), 16000)

        splitter = self._split_80_20()

        self.assertEqual(len(splitter.train_set), 8)
        self.assertEqual(len(splitter.val_set), 2)

    @patch('torchaudio.load')
    def test_get_loaders(self, mock_load):
        """Batch sizes 2 and 1 yield four training batches and two validation batches."""
        # Set the value returned by the mocked torchaudio.load.
        mock_load.return_value = (torch.zeros(1, 16000), 16000)

        loaders = self._split_80_20().get_loaders(
            train_batch_size=2, val_batch_size=1, num_workers=0
        )
        train_loader, val_loader = loaders

        # Exhaust both loaders so every sample actually goes through __getitem__.
        train_batches = list(train_loader)
        val_batches = list(val_loader)
        self.assertEqual(len(train_batches), 4)
        self.assertEqual(len(val_batches), 2)


if __name__ == '__main__':
    unittest.main()