Attention unet을 이용한 segmentation 진행한 거를 정리해본다.
데이터는 kaggle 데이터를 사용했으며 리더보드에 다른 사람들의 코드들도 많이 있다.
기본 레이어로도 어느정도 성능이 나온다.
데이터셋에서 0~1 사이의 정규화만 시켜도 훈련이 된다.
normalization 이나 aug를 추가하면 물론 더욱 잘 된다.
확실히 GAN이 훈련시키기 어렵구나 하였다..
[dataset]
class brain_dataset(data.Dataset):
def __init__(self):
super(brain_dataset, self).__init__()
ROOT_PATH = './kaggle_3m/'
self.mask_file = glob.glob(ROOT_PATH + '**/*mask*', recursive=True)
self.image_file =[i.replace('_mask','') for i in self.mask_file]
def __getitem__(self, index):
mask = self.mask_file[index]
image = self.image_file[index]
data = cv2.imread(image)
data = data/255.
data = data.transpose(2,0,1)
label = cv2.imread(mask,0)
label = np.expand_dims(label, axis=-1)
label = label/255.
label = label.transpose(2,0,1)
return torch.from_numpy(data).float(), torch.from_numpy(label).float()
def __len__(self):
return len(self.image_file)
[결과]
파이토치에서도 brain segmentation model을 지원한다.
p_model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet',
in_channels=3, out_channels=1, init_features=32, pretrained=True).to(device)