음악 분류 딥러닝을 만들자(6) - 데이터 라벨링

응큼한포도·2024년 6월 20일
1
post-thumbnail

설계

기본적으로 데이터 관리는 mysql 같은 데이터 베이스를 이용해서 하는 것이 좋다. 성능과 관리에서 아주 뛰어나기 때문에 서비스를 만든다면 무조건 데이터 베이스에 데이터를 관리하자.

하지만 이번엔 간단한 실습 프로젝트이고 데이터의 개수가 작기 때문에 라벨링한 데이터를 pandas로 관리하겠다.

만든 데이터를 라벨링하자

요즘엔 라벨링을 안하는 방법도 있다고 하지만 이번엔 라벨링을 하겠다.

우선 저번시간에 모든 가수를 전처리하고 한 폴더 안에 넣어주자.

import os
import pandas as pd
from pathlib import Path
from preprocessing.file_system_helper import FileSystemHelper

class MelSpectrogramLabeler:
    def __init__(self, base_dir):
        """
        Class for labeling Mel spectrogram files within a directory structure.

        Args:
            base_dir (str or Path): Base directory path containing Mel spectrogram data.

        Attributes:
            base_dir (Path): Path to the base directory containing Mel spectrogram data.
            data (list of dict): List to store labeled spectrogram data in the format {'file_path': str, 'label': str}.
        """
        self.base_dir = Path(base_dir)
        self.data = []

    def label_spectrogram(self, mel_dir_name, extension):
        """
        Label Mel spectrogram files within artist directories in the base directory.

        Args:
            mel_dir_name (str): Name of the directory containing Mel spectrogram files.
            extension (str): File extension of Mel spectrogram files.

        Returns:
            None
        """
        for artist_dir in self.base_dir.iterdir():
            if not artist_dir.is_dir():
                continue

            artist_name = artist_dir.name
            mel_dir = artist_dir.joinpath(mel_dir_name)

            if mel_dir.exists() and mel_dir.is_dir():
                mel_files = FileSystemHelper.get_files_by_extension(mel_dir, extension)
                for mel_file in mel_files:
                    self.data.append({
                        'file_path': str(mel_file),
                        'label': artist_name
                    })

    def save_to_dataframe(self):
        """
        Convert labeled spectrogram data to a pandas DataFrame.

        Returns:
            pd.DataFrame: DataFrame containing 'file_path' and 'label' columns.
        """
        return pd.DataFrame(self.data)

    def save_to_csv(self, output_path):
        """
        Save labeled spectrogram data to a CSV file.

        Args:
            output_path (str or Path): Path to save the CSV file.

        Returns:
            None
        """
        df = self.save_to_dataframe()
        df.to_csv(output_path, index=True)

fileSystemHelper 확장

이 클래스안에 파일 경로 확인하는 매서드가 있는 데 저번에 만든 fileSystemHelper를 확장해주자.

import os

class FileSystemHelper:
    """
    A helper class for filesystem operations related to WAV files.

    Methods
    -------
    make_dir(path):
        Creates a directory if it doesn't already exist.

    get_wav_files(folder_path):
        Retrieves a list of paths to WAV files within a specified folder.

    """

    @staticmethod
    def ensure_directory_exists(path):
        """
        Ensures that a directory exists at the given path. Creates it if necessary.

        Parameters
        ----------
        path : str
            The path of the directory to ensure existence.
        """
        if not os.path.exists(path):
            os.makedirs(path)

    @staticmethod
    def get_files_by_extension(folder_path, extension):
        """
        Retrieves a list of paths to files with a specific extension within a folder.

        Parameters
        ----------
        folder_path : str
            The path of the folder where files are located.
        extension : str
            The extension of files to retrieve (e.g., 'wav', 'png').

        Returns
        -------
        list
            A list of paths to files with the specified extension (absolute paths).
        """
        return [os.path.join(folder_path, file) for file in os.listdir(folder_path) if file.endswith(extension)
                and os.path.isfile(os.path.join(folder_path, file))]

테스트 코드

import unittest
import tempfile
import pandas as pd
from pathlib import Path
from preprocessing.mel_spectrogram_labeler import MelSpectrogramLabeler

class TestMelSpectrogramLabeler(unittest.TestCase):

    def setUp(self):
        # 임시 디렉터리 생성
        self.temp_dir = tempfile.TemporaryDirectory()

        # 가상의 기본 디렉터리 생성
        self.base_dir = Path(self.temp_dir.name) / 'mock_base'
        self.base_dir.mkdir()

        # 가상의 아티스트 디렉터리와 mel 파일 생성
        artist1_dir = self.base_dir / 'artist1'
        artist1_dir.mkdir()
        mel1_dir = artist1_dir / 'mel'
        mel1_dir.mkdir()
        (mel1_dir / 'spectrogram1.png').touch()

        artist2_dir = self.base_dir / 'artist2'
        artist2_dir.mkdir()
        mel2_dir = artist2_dir / 'mel'
        mel2_dir.mkdir()
        (mel2_dir / 'spectrogram2.png').touch()

        # 테스트할 MelSpectrogramLabeler 객체 생성
        self.labeler = MelSpectrogramLabeler(self.base_dir)

        # 예상 경로와 레이블 정의
        self.expected_paths = [
            str(self.base_dir / 'artist1' / 'mel' / 'spectrogram1.png'),
            str(self.base_dir / 'artist2' / 'mel' / 'spectrogram2.png')
        ]
        self.expected_labels = ['artist1', 'artist2']

    def tearDown(self):
        # 임시 디렉터리 정리
        self.temp_dir.cleanup()

    def test_label_spectrogram(self):
        # label_spectrogram 메서드 호출
        self.labeler.label_spectrogram('mel', 'png')

        # 데이터 검증
        self.assertEqual(len(self.labeler.data), 2)
        for entry in self.labeler.data:
            self.assertIn(entry['file_path'], self.expected_paths)
            self.assertIn(entry['label'], self.expected_labels)

    def test_save_to_dataframe(self):
        # label_spectrogram 메서드 호출
        self.labeler.label_spectrogram('mel', 'png')

        # save_to_dataframe 메서드 호출 및 데이터프레임 반환 검증
        df = self.labeler.save_to_dataframe()
        self.assertIsInstance(df, pd.DataFrame)
        self.assertEqual(len(df), 2)
        expected_columns = ['file_path', 'label']
        self.assertListEqual(list(df.columns), expected_columns)

        # label_spectrogram 메서드에서 생성된 데이터와 일치하는 값을 검증
        for index, row in df.iterrows():
            self.assertIn(row['file_path'], self.expected_paths)
            self.assertIn(row['label'], self.expected_labels)

    def test_save_to_csv(self):
        # label_spectrogram 메서드 호출
        self.labeler.label_spectrogram('mel', 'png')

        # 임시 CSV 파일 경로 설정
        temp_csv_path = self.temp_dir.name + '/output.csv'

        # save_to_csv 메서드 호출
        self.labeler.save_to_csv(temp_csv_path)

        # CSV 파일 존재 여부 확인
        self.assertTrue(Path(temp_csv_path).exists())

        # CSV 파일 내용 검증
        df = pd.read_csv(temp_csv_path)
        self.assertIsInstance(df, pd.DataFrame)
        self.assertEqual(len(df), 2)
        expected_columns = ['file_path', 'label']
        self.assertListEqual(list(df.columns), expected_columns)

        # label_spectrogram 메서드에서 생성된 데이터와 일치하는 값을 검증
        for index, row in df.iterrows():
            self.assertIn(row['file_path'], self.expected_paths)
            self.assertIn(row['label'], self.expected_labels)

if __name__ == '__main__':
    unittest.main()

실행 코드

from pathlib import Path
from preprocessing.mel_spectrogram_labeler import MelSpectrogramLabeler

def main():
    # Define the base directory where your Mel spectrogram data is located
    base_directory = '/Users/seong-gyeongjun/Downloads/vocal artist'

    # Create an instance of MelSpectrogramLabeler
    labeler = MelSpectrogramLabeler(base_directory)

    # Define the directory name containing Mel spectrogram files and their extension
    mel_dir_name = 'mel'
    extension = 'png'

    # Label the spectrograms
    labeler.label_spectrogram(mel_dir_name, extension)

    # Save the labeled data to a CSV file
    output_csv_path = '/Users/seong-gyeongjun/Downloads/vocal_label/output.csv'
    labeler.save_to_csv(output_csv_path)

    # Optionally, you can also save the labeled data to a DataFrame
    labeled_df = labeler.save_to_dataframe()

    # Print or use the labeled DataFrame as needed
    print(labeled_df.head())

if __name__ == "__main__":
    main()

데이터

위와 같이 라벨링 데이터를 얻을 수 있다

profile
미친 취준생

0개의 댓글