음악 분류 딥러닝을 만들자(7) - 데이터셋과 training, validation set

응큼한포도·2024년 6월 22일
0

Dataset을 커스터마이징하자

pytorch dataset

이번에는 파이토치의 Dataset 클래스를 상속받아 사용자 정의 Dataset을 만들 것이다. 이 작업이 꼭 필요한 이유는 프로젝트마다 도메인과 데이터 저장 방식이 달라서, 각 프로젝트에 맞게 데이터셋을 직접 정의해야 하기 때문이다.

주의할 점: Dataset을 상속받을 때 `__len__(self)`와 `__getitem__(self, index)`는 반드시 재정의해야 한다.

Dataset

import pandas as pd
import torchaudio
from torch.utils.data import Dataset


class SoundDataset(Dataset):
    """
    A dataset class that loads audio files listed in a CSV annotations file.

    The annotations are read with ``pandas.read_csv`` and accessed by column
    *position* (``iloc``). Column names are therefore not checked, but the
    order matters: column 0 = id, column 1 = audio file path, column 2 = label.

    Attributes:
    annotations (pandas.DataFrame): One row per audio file.

    Methods:
    __len__(): Returns the total number of audio files in the dataset.
    __getitem__(index): Returns the audio signal and label for a given index.
    _get_audio_path(index): Returns the path to the audio file at the given index.
    _get_audio_label(index): Returns the label of the audio file at the given index.
    """

    # Positional index of each field within an annotation row.
    _PATH_COLUMN = 1
    _LABEL_COLUMN = 2

    def __init__(self, annotations_files):
        """
        Initializes the SoundDataset by reading the annotations CSV.

        Parameters:
        annotations_files (str | os.PathLike | file-like): Any source accepted
            by ``pandas.read_csv``, e.g. a path to a CSV file or a text buffer
            such as ``io.StringIO``.

        Raises:
        ValueError: If the CSV has fewer than three columns. In that case the
            positional path/label lookups could never succeed.
        """
        self.annotations = pd.read_csv(annotations_files)
        # Fail fast here instead of with an opaque iloc IndexError on first access.
        if self.annotations.shape[1] <= self._LABEL_COLUMN:
            raise ValueError(
                "annotations must have at least 3 columns [id, path, label], "
                f"got {self.annotations.shape[1]}"
            )

    def __len__(self):
        """
        Returns the total number of audio files in the dataset.

        Returns:
        int: The number of rows in the annotations file.
        """
        return len(self.annotations)

    def __getitem__(self, index):
        """
        Returns the audio signal and label for a given index.

        Parameters:
        index (int): Row position of the audio file in the annotations.

        Returns:
        tuple: ``(signal, label)``. ``signal`` is the waveform returned by
        ``torchaudio.load``; ``label`` is the value of the label column.
        """
        audio_path = self._get_audio_path(index)
        label = self._get_audio_label(index)
        # torchaudio.load returns (waveform, sample_rate). The sample rate is
        # not used by this dataset, so it is discarded.
        signal, _ = torchaudio.load(audio_path)
        return signal, label

    def _get_audio_path(self, index):
        """
        Returns the path to the audio file at the given index.

        Parameters:
        index (int): Row position of the audio file.

        Returns:
        The value stored in the path column (column 1) of that row.
        """
        return self.annotations.iloc[index, self._PATH_COLUMN]

    def _get_audio_label(self, index):
        """
        Returns the label of the audio file at the given index.

        Parameters:
        index (int): Row position of the audio file.

        Returns:
        The value stored in the label column (column 2) of that row. Its type
        is whatever pandas inferred for the column, e.g. ``str`` for text labels.
        """
        return self.annotations.iloc[index, self._LABEL_COLUMN]

이번 프로젝트의 라벨 정보는 CSV 파일로 관리하고 pandas로 읽어 들였다. 컬럼은 index, file_path, label 순서이며, 값은 컬럼 이름이 아니라 iloc을 이용한 위치 기반 인덱싱으로 가져온다. 따라서 CSV의 컬럼 순서가 바뀌면 안 된다.

테스트 코드

import unittest
import pandas as pd
import os
from unittest.mock import patch
import torch
from data.sound_dataset import SoundDataset 

class TestSoundDataset(unittest.TestCase):
    """Unit tests for SoundDataset that use a fake CSV and a mocked torchaudio.load."""

    def setUp(self):
        # Local import: only the fixture needs it.
        import tempfile

        # Write the fake annotations CSV into a private temporary directory.
        # This keeps the test from overwriting or deleting a file in the
        # working directory, and stops concurrent runs colliding on one name.
        tmpdir = tempfile.TemporaryDirectory()
        # addCleanup callbacks run even if a later line of setUp raises,
        # whereas tearDown is skipped in that case and would leak the file.
        self.addCleanup(tmpdir.cleanup)

        self.annotations = pd.DataFrame([
            [0, "fake_path_1.wav", "label_1"],
            [1, "fake_path_2.wav", "label_2"],
        ], columns=["id", "path", "label"])
        self.annotations_file = os.path.join(tmpdir.name, "test_annotations.csv")
        self.annotations.to_csv(self.annotations_file, index=False)

        # Object under test.
        self.dataset = SoundDataset(self.annotations_file)

    def test_len(self):
        self.assertEqual(len(self.dataset), 2)

    def test_getitem(self):
        # Patch torchaudio.load so no real audio file is needed.
        with patch('torchaudio.load') as mocked_load:
            mocked_load.return_value = (torch.zeros((1, 16000)), 16000)
            signal, label = self.dataset[0]
            mocked_load.assert_called_once_with("fake_path_1.wav")
            self.assertTrue(torch.equal(signal, torch.zeros((1, 16000))))
            self.assertEqual(label, "label_1")

    def test_getitem_last_item(self):
        with patch('torchaudio.load') as mocked_load:
            mocked_load.return_value = (torch.ones((1, 8000)), 8000)
            signal, label = self.dataset[1]
            mocked_load.assert_called_once_with("fake_path_2.wav")
            self.assertTrue(torch.equal(signal, torch.ones((1, 8000))))
            self.assertEqual(label, "label_2")

    def test_get_audio_path(self):
        self.assertEqual(self.dataset._get_audio_path(0), "fake_path_1.wav")
        self.assertEqual(self.dataset._get_audio_path(1), "fake_path_2.wav")

    def test_get_audio_label(self):
        self.assertEqual(self.dataset._get_audio_label(0), "label_1")
        self.assertEqual(self.dataset._get_audio_label(1), "label_2")

    def test_out_of_range_index(self):
        # Positional lookup one past the last row must fail rather than wrap.
        with self.assertRaises(IndexError):
            self.dataset._get_audio_path(len(self.dataset))

if __name__ == '__main__':
    unittest.main()

커스텀한 데이터셋을 training set과 validation set으로 나누자

딥러닝 모델을 학습할 때는 데이터를 training set과 validation set으로 나눠 성능을 검증하는데, 이 역할을 담당할 class를 만들자. 분할 비율, 배치 크기, worker 수 같은 중요한 매개변수는 조정할 수 있도록 했다.

from torch.utils.data import Dataset, DataLoader, random_split

class DataSplitter:
    """
    Split a dataset into training and validation subsets and wrap them in DataLoaders.

    Attributes:
    dataset (Dataset): The dataset to be split.
    train_set (Subset | None): The training subset; None until split_dataset() runs.
    val_set (Subset | None): The validation subset; None until split_dataset() runs.

    Methods:
    split_dataset(train_ratio, seed): Splits the dataset into training and validation sets.
    get_loaders(train_batch_size, val_batch_size, num_workers): Returns DataLoaders for the splits.
    """

    def __init__(self, dataset):
        """
        Initializes the DataSplitter with the given dataset.

        Parameters:
        dataset (Dataset): The dataset to be split. It must support ``len()``
            and integer indexing.
        """
        self.dataset = dataset
        self.train_set = None
        self.val_set = None

    def split_dataset(self, train_ratio=0.8, seed=None):
        """
        Randomly splits the dataset into training and validation sets.

        The training set receives ``int(train_ratio * len(dataset))`` items
        (floored); the validation set receives the remainder.

        Parameters:
        train_ratio (float): Fraction of items for the training set, within
            [0, 1]. Default is 0.8.
        seed (int | None): If given, the split is drawn from a dedicated
            generator seeded with this value, making it reproducible. If None
            (default), torch's global random state is used.

        Raises:
        ValueError: If train_ratio is outside [0, 1] (or is NaN).
        """
        # Written this way so NaN also fails the check.
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be within [0, 1], got {train_ratio!r}")

        total = len(self.dataset)
        train_size = int(train_ratio * total)
        val_size = total - train_size

        if seed is None:
            self.train_set, self.val_set = random_split(self.dataset, [train_size, val_size])
        else:
            import torch  # local import: only needed for a seeded generator

            generator = torch.Generator().manual_seed(seed)
            self.train_set, self.val_set = random_split(
                self.dataset, [train_size, val_size], generator=generator
            )

    def get_loaders(self, train_batch_size=32, val_batch_size=32, num_workers=2):
        """
        Returns DataLoaders for the training and validation sets.

        The training loader shuffles; the validation loader keeps subset order.

        Parameters:
        train_batch_size (int): Batch size for the training set. Default is 32.
        val_batch_size (int): Batch size for the validation set. Default is 32.
        num_workers (int): Number of DataLoader worker subprocesses; 0 loads
            in the main process. Default is 2.

        Returns:
        tuple: ``(train_loader, val_loader)``.

        Raises:
        RuntimeError: If split_dataset() has not been called yet.
        """
        if self.train_set is None or self.val_set is None:
            raise RuntimeError("split_dataset() must be called before get_loaders()")

        train_loader = DataLoader(self.train_set, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)
        val_loader = DataLoader(self.val_set, batch_size=val_batch_size, shuffle=False, num_workers=num_workers)
        return train_loader, val_loader

단위 테스트

from data.data_splitter import DataSplitter
from data.sound_dataset import SoundDataset
import unittest
from unittest.mock import patch, MagicMock
from io import StringIO
import torch

class TestDataSplitter(unittest.TestCase):
    """Unit tests for DataSplitter on an in-memory SoundDataset with mocked audio loading."""

    def setUp(self):
        # Build the fake annotations CSV without leading indentation.
        # Indentation inside a triple-quoted literal would become part of
        # every row's first field.
        labels = ["label1"] * 3 + ["label2"] * 3 + ["label3"] * 4
        rows = [
            f"{i},path/to/audio{i}.wav,{label}"
            for i, label in enumerate(labels, start=1)
        ]
        csv_data = "\n".join(["index,path,label"] + rows)
        # A fresh buffer per test: read_csv consumes the stream.
        self.annotations_file = StringIO(csv_data)
        self.dataset = SoundDataset(self.annotations_file)

    @patch('torchaudio.load')
    def test_split_dataset(self, mock_load):
        # Set the return value of the mocked torchaudio.load.
        mock_load.return_value = (torch.zeros(1, 16000), 16000)

        splitter = DataSplitter(self.dataset)
        splitter.split_dataset(train_ratio=0.8)

        self.assertEqual(len(splitter.train_set), 8)
        self.assertEqual(len(splitter.val_set), 2)

        # Every row must land in exactly one subset: no overlap, nothing lost.
        train_idx = set(splitter.train_set.indices)
        val_idx = set(splitter.val_set.indices)
        self.assertFalse(train_idx & val_idx)
        self.assertEqual(train_idx | val_idx, set(range(len(self.dataset))))

    @patch('torchaudio.load')
    def test_get_loaders(self, mock_load):
        # Set the return value of the mocked torchaudio.load.
        mock_load.return_value = (torch.zeros(1, 16000), 16000)

        splitter = DataSplitter(self.dataset)
        splitter.split_dataset(train_ratio=0.8)
        train_loader, val_loader = splitter.get_loaders(train_batch_size=2, val_batch_size=1, num_workers=0)

        train_batches = list(train_loader)
        val_batches = list(val_loader)

        self.assertEqual(len(train_batches), 4)
        self.assertEqual(len(val_batches), 2)
        # Each batch is [signals, labels]; the signals are stacked along a new batch dim.
        self.assertEqual(tuple(train_batches[0][0].shape), (2, 1, 16000))
        self.assertEqual(tuple(val_batches[0][0].shape), (1, 1, 16000))


if __name__ == '__main__':
    unittest.main()
profile
미친 취준생

0개의 댓글