pretrained DeepLab 모델을 사용하면 이미지의 각 픽셀을 21개 클래스(객체 20종 + 배경) 중 하나로 분류하는 segmentation을 수행할 수 있습니다.
cv2를 함께 사용하여, 원하는 부분만 선명하게 남기고 나머지를 blur 처리하는 Shallow Focus(아웃포커싱) 효과를 구현해 봅니다.
import libraries & load image
import os
import tarfile
import urllib
import urllib.request

import cv2
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
class DeepLabModel(object):
    """Run inference with a frozen DeepLab graph stored in a tarball.

    The network is not rebuilt in Python. The serialized ``GraphDef`` found
    in the archive is imported into a ``tf.Graph`` and executed through a
    TF1-style session (``tf.compat.v1``).
    """

    # Names of the graph tensors that are fed and fetched in run().
    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    # Length in pixels that the longer image side is scaled to in preprocess().
    INPUT_SIZE = 513
    # Substring that identifies the frozen-graph member inside the tarball.
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, tarball_path):
        """Load the frozen graph from *tarball_path* and open a session on it.

        Args:
            tarball_path: Path to a tar archive containing a member whose
                base name includes ``FROZEN_GRAPH_NAME``.

        Raises:
            RuntimeError: If no such member can be read from the archive.
        """
        graph_def = None
        # The context manager closes the archive even when reading or
        # parsing a member raises.
        with tarfile.open(tarball_path) as tar_file:
            for tar_info in tar_file.getmembers():
                if self.FROZEN_GRAPH_NAME not in os.path.basename(tar_info.name):
                    continue
                file_handle = tar_file.extractfile(tar_info)
                if file_handle is None:
                    # extractfile() yields None for members that carry no
                    # data (e.g. directories); keep searching.
                    continue
                graph_def = tf.compat.v1.GraphDef.FromString(file_handle.read())
                break

        if graph_def is None:
            # Fail here with a clear message instead of handing None to
            # import_graph_def below.
            raise RuntimeError('Cannot find inference graph in tar archive.')

        self.graph = tf.Graph()
        with self.graph.as_default():
            tf.compat.v1.import_graph_def(graph_def, name='')
        self.sess = tf.compat.v1.Session(graph=self.graph)

    def preprocess(self, img_orig):
        """Resize *img_orig* for the network and convert it from BGR to RGB.

        The image is scaled, keeping its aspect ratio, so that its longer
        side becomes ``INPUT_SIZE`` pixels.

        Args:
            img_orig: Image array whose first two axes are (height, width).
                Assumed to be a 3-channel BGR image as returned by
                ``cv2.imread`` -- the BGR->RGB conversion relies on that.

        Returns:
            The resized image in RGB channel order.
        """
        height, width = img_orig.shape[:2]
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        # Clamp to 1 px: with an extreme aspect ratio int() can truncate the
        # shorter side to 0, which is not a valid resize target.
        target_size = (
            max(1, int(resize_ratio * width)),
            max(1, int(resize_ratio * height)),
        )
        resized_image = cv2.resize(img_orig, target_size)
        return cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)

    def run(self, image):
        """Segment *image* and return the resized image with its label map.

        Args:
            image: Image array accepted by ``preprocess``.

        Returns:
            A ``(resized_bgr, seg_map)`` tuple: the preprocessed image
            converted back to BGR, and the first (only) element of the
            batched ``OUTPUT_TENSOR_NAME`` result.
        """
        img_input = self.preprocess(image)
        # TF1 graphs are executed with sess.run(fetches, feed_dict=...)
        # rather than by calling model(input). The image is wrapped in a
        # list to form a batch of one.
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [img_input]})
        seg_map = batch_seg_map[0]
        return cv2.cvtColor(img_input, cv2.COLOR_RGB2BGR), seg_map
학습이 완료된(pretrained) 모델을 다운로드한 뒤, DeepLabModel 객체를 생성하여 모델을 불러옵니다.
# Download the pretrained DeepLab archive (only if it is not already on disk)
# and load it into the module-level MODEL used by the functions below.
_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'

# expanduser('~') resolves the home directory even when the HOME environment
# variable is unset, where os.getenv('HOME') + '...' would raise TypeError.
model_dir = os.path.join(
    os.path.expanduser('~'), 'aiffel', 'human_segmentation', 'models')
tf.io.gfile.makedirs(model_dir)
print('temp directory:', model_dir)

download_path = os.path.join(model_dir, 'deeplab_model.tar.gz')
if not os.path.exists(download_path):
    # Download to a side file and rename it once complete, so an interrupted
    # transfer never leaves a truncated archive at download_path that the
    # existence check above would accept on the next run.
    _partial_path = download_path + '.part'
    urllib.request.urlretrieve(
        _DOWNLOAD_URL_PREFIX + 'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
        _partial_path)
    os.replace(_partial_path, download_path)

MODEL = DeepLabModel(download_path)
print('model loaded successfully!')
# Names of the 21 segmentation categories: background plus 20 object classes.
# A name's position in this list is the integer id it is encoded as below.
LABEL_NAMES = [
    "background",                                                   # 0
    "aeroplane", "bicycle", "bird", "boat", "bottle",               # 1-5
    "bus", "car", "cat", "chair", "cow",                            # 6-10
    "diningtable", "dog", "horse", "motorbike", "person",           # 11-15
    "pottedplant", "sheep", "sofa", "train", "tv",                  # 16-20
]

# Reverse lookup: category name -> integer id (its index in LABEL_NAMES).
LABEL_ENCODING = dict(zip(LABEL_NAMES, range(len(LABEL_NAMES))))
def basic_shallow_focus(img_path, label_str, blur_level=13):
    """Display an image with everything except one class blurred.

    The image is segmented with the module-level ``MODEL``. Pixels predicted
    as ``label_str`` keep their original values and every other pixel is
    replaced by a box-blurred copy. The result is shown with matplotlib;
    nothing is returned.

    Args:
        img_path: Path of the image file to load with ``cv2.imread``.
        label_str: Class name to keep in focus; must be a key of
            ``LABEL_ENCODING``.
        blur_level: Side length of the square blur kernel passed to
            ``cv2.blur``.

    Raises:
        KeyError: If ``label_str`` is not a key of ``LABEL_ENCODING``.
        OSError: If the image cannot be read from ``img_path``.
    """
    label = LABEL_ENCODING[label_str]

    img_orig = cv2.imread(img_path)
    if img_orig is None:
        # cv2.imread reports failure by returning None rather than raising;
        # stop here instead of failing later inside MODEL.run.
        raise OSError('could not read image file: {!r}'.format(img_path))

    # Only the label map is needed; the resized image MODEL.run also returns
    # is discarded.
    _, seg_map = MODEL.run(img_orig)

    # Build the 0/255 mask directly from the comparison. Keeping the label
    # value itself and normalising by the maximum gives an all-zero mask for
    # label 0 ("background"), because the matching pixels are then
    # indistinguishable from the fill value 0.
    img_mask = np.where(seg_map == label, 255, 0).astype(np.uint8)

    # The mask has the resolution of the model output, so scale it up to the
    # original image size (cv2 expects (width, height)). Re-threshold it
    # because linear interpolation produces intermediate values on the edges.
    img_mask_up = cv2.resize(
        img_mask, img_orig.shape[:2][::-1], interpolation=cv2.INTER_LINEAR
    )
    _, img_mask_up = cv2.threshold(img_mask_up, 128, 255, cv2.THRESH_BINARY)
    img_mask_color = cv2.cvtColor(img_mask_up, cv2.COLOR_GRAY2BGR)

    # blur_level is the blur kernel size.
    img_orig_blur = cv2.blur(img_orig, (blur_level, blur_level))

    # Take the sharp pixel where the mask is set and the blurred pixel
    # everywhere else.
    img_concat = np.where(img_mask_color == 255, img_orig, img_orig_blur)

    # matplotlib expects RGB channel order, OpenCV images are BGR.
    plt.imshow(cv2.cvtColor(img_concat, cv2.COLOR_BGR2RGB))
    plt.show()