[AI] Load image data with CustomDataLoader from json

JAsmine_log·2024년 8월 28일
0

Load image data with CustomDataLoader from json

Json 이나 텍스트 데이터셋으로부터 image를 부르기 위한 방법이다. 여기서는 json을 활용해 보겠다!

CustomDataLoader Class

class MultiJsonDataset(Dataset):
    def __init__(self, json_dir, root_dir, transform=None):
        """
        Args:
            json_dir (string): JSON 파일들이 위치한 디렉토리 경로
            root_dir (string): 이미지들이 위치한 디렉토리 경로
            transform (callable, optional): 이미지에 적용할 변환(transforms)
        """
        self.json_files = glob(os.path.join(json_dir, '*.json'))
        self.root_dir = root_dir
        self.transform = transform
        self.data_list = []

        for json_file in self.json_files:
            with open(json_file, 'r') as f:
                data_info = json.load(f)
            img_name = data_info['meta']['img']
            labels = data_info['series']['labels']
            data = data_info['data']['data']

            for i in range(data_info['data']['data_len']):
                item = {
                    'img_name': img_name,
                    'label': torch.tensor(data[labels[0]][i])
                }
                self.data_list.append(item)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.data_list[idx]['img_name'])
        image = Image.open(img_path).convert('RGB')

        label = self.data_list[idx]['label']

        if self.transform:
            image = self.transform(image)

        return image, label

main process


# JSON 파일이 있는 디렉토리 경로 및 이미지 디렉토리 설정
json_dir = 'path/to/your/json_directory'
root_dir = 'path/to/your/images_directory'

# 이미지에 적용할 변환(예시)
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# MultiJsonDataset 인스턴스 생성
dataset = MultiJsonDataset(json_dir=json_dir, root_dir=root_dir, transform=transform)

# DataLoader 설정
data_loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)

# 데이터 로드 및 학습 예시
for images, labels in data_loader:
    # images는 배치 크기만큼의 이미지 텐서, labels는 해당 이미지들의 라벨 텐서
    print(images.shape, labels)

Image Visualization


# 이미지 및 라벨 시각화
for images, labels in data_loader:
    # 각 배치에서 이미지를 반복적으로 가져와서 시각화
    for i in range(images.size(0)):
        image = images[i].permute(1, 2, 0).numpy()  # 이미지 텐서를 numpy로 변환
        label = labels[i].item()  # 텐서에서 라벨 값을 가져옴

        plt.imshow(image)
        plt.title(f"Label: {label}")
        plt.axis('off')
        plt.show()
        
profile
Everyday Research & Development

0개의 댓글