Ex8-Shallow_Focus 구현

안 형준·2021년 8월 17일
0

Aiffel/Archives

목록 보기
9/12

pretrained deeplab model을 사용하면, 이미지를 21개 (20 클래스 + 배경)으로 segmentation 할 수 있습니다.
cv2를 함께 사용하여 원하는 부분을 제외하고 blur 시키는 Soft Focus를 구현해 봅니다

load model & pretrained weights

import libraries & load image

import cv2
import numpy as np
import os
import tarfile
import urllib

from matplotlib import pyplot as plt
import tensorflow as tf
class DeepLabModel(object):
    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 513
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    # __init__()에서 모델 구조를 직접 구현하는 대신, tar file에서 읽어들인 그래프구조 graph_def를 
    # tf.compat.v1.import_graph_def를 통해 불러들여 활용하게 됩니다. 
    def __init__(self, tarball_path):
        self.graph = tf.Graph()
        graph_def = None
        tar_file = tarfile.open(tarball_path)
        for tar_info in tar_file.getmembers():
            if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
                file_handle = tar_file.extractfile(tar_info)
                graph_def = tf.compat.v1.GraphDef.FromString(file_handle.read())
                break
        tar_file.close()

        with self.graph.as_default():
    	    tf.compat.v1.import_graph_def(graph_def, name='')

        self.sess = tf.compat.v1.Session(graph=self.graph)

    # 이미지를 전처리하여 Tensorflow 입력으로 사용 가능한 shape의 Numpy Array로 변환합니다.
    def preprocess(self, img_orig):
        height, width = img_orig.shape[:2]
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        resized_image = cv2.resize(img_orig, target_size)
        resized_rgb = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)
        img_input = resized_rgb
        return img_input
        
    def run(self, image):
        img_input = self.preprocess(image)

        # Tensorflow V1에서는 model(input) 방식이 아니라 sess.run(feed_dict={input...}) 방식을 활용합니다.
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [img_input]})

        seg_map = batch_seg_map[0]
        return cv2.cvtColor(img_input, cv2.COLOR_RGB2BGR), seg_map

DeepLabModel 객체를 생성하고, 훈련이 완료된 모델을 다운로드하여 불러옵니다

# define model and download & load pretrained weight
_DOWNLOAD_URL_PREFIX = 'http://download.tensorflow.org/models/'

model_dir = os.getenv('HOME')+'/aiffel/human_segmentation/models'
tf.io.gfile.makedirs(model_dir)

print ('temp directory:', model_dir)

download_path = os.path.join(model_dir, 'deeplab_model.tar.gz')
if not os.path.exists(download_path):
    urllib.request.urlretrieve(_DOWNLOAD_URL_PREFIX + 'deeplabv3_mnv2_pascal_train_aug_2018_01_29.tar.gz',
                   download_path)

MODEL = DeepLabModel(download_path)
print('model loaded successfully!')

Shallow Focus

LABEL_NAMES = [
    "background",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tv",
]

LABEL_ENCODING = {label: idx for idx, label in enumerate(LABEL_NAMES)}


def basic_shallow_focus(img_path, label_str, blur_level=13):
    label = LABEL_ENCODING[label_str]

    img_orig = cv2.imread(img_path)
    img_resized, seg_map = MODEL.run(img_orig)

    seg_map = np.where(seg_map == label, label, 0)  # 예측 중 label만 추출
    img_mask = seg_map * (
        255 / (seg_map.max() + 1e-10)
    )  # 255 normalization, divisionerror 방지
    img_mask = img_mask.astype(np.uint8)

    img_mask_up = cv2.resize(
        img_mask, img_orig.shape[:2][::-1], interpolation=cv2.INTER_LINEAR
    )
    _, img_mask_up = cv2.threshold(img_mask_up, 128, 255, cv2.THRESH_BINARY)
    img_mask_color = cv2.cvtColor(img_mask_up, cv2.COLOR_GRAY2BGR)

    img_orig_blur = cv2.blur(
        img_orig, (blur_level, blur_level)
    )  # blurring  kernel size를 뜻합니다.

    img_bg_mask = cv2.bitwise_not(img_mask_color)
    img_bg_blur = cv2.bitwise_and(img_orig_blur, img_bg_mask)

    img_concat = np.where(img_mask_color == 255, img_orig, img_bg_blur)

    plt.imshow(cv2.cvtColor(img_concat, cv2.COLOR_BGR2RGB))
    plt.show()

결과

profile
물리학과 졸업/ 인공지능 개발자로의 한 걸음

0개의 댓글