[이미지처리바이블] 5.3장 이미지 분할: Segmentation

Changh2·2025년 6월 1일
0

이미지처리바이블

목록 보기
7/7

[이미지 처리 바이블] 교재 5장을 기반으로 작성되었습니다.


5.3 이미지 분할: Segmentation

> 이미지를 구성하는 각 픽셀이 어떤 객체에 속하는지 분류하는 과정

5.3.1 FCN

FCN 등장 이전의 알고리즘

색상 히스토그램

> 이미지 내 특정 색상의 빈도를 그래픽으로 나타냄
> 그로인해 다양한 정보를 얻을 수 있음

특정 색상 c가 이미지 내에서 나타나는 빈도를 측정하는 수식:

H(c) = 색상 c의 히스토그램 빈도,
I(x,y) = 좌표 (x,y)에 위치한 픽셀의 색상 값,
W = 이미지의 너비, H = 이미지의 높이


임계 값 처리

> 이미지 내의 픽셀을 특정 기준에 따라 분류하고 단순화하는 과정
> 디지털 이미지를 이진화하여 객체를 감지하거나 분할할 때 사용하는 알고리즘


OpenCV를 활용한 임계 값 처리 실습

""" 📌 이미지 불러오기 """
import cv2
import matplotlib.pyplot as plt

!wget https://raw.githubusercontent.com/Lilcob/test_colab/main/three%20young%20man.jpg

new_image_color = '/content/three young man.jpg'
new_image_color = cv2.imread(new_image_color)
image_gray = cv2.cvtColor(new_image_color, cv2.COLOR_BGR2GRAY)

""" 📌 임계 값 설정 """
threshold_value = 128
_, thresholded_image = cv2.threshold(image_gray, threshold_value, 255, cv2.THRESH_BINARY)

""" 📌 결과 이미지 시각화 """
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(image_gray, cmap='gray')
plt.title('Original Grayscale Image')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(thresholded_image, cmap='gray')
plt.title('Binary Image after Thresholding')
plt.axis('off')


FCN

> Fully Convolutional Networks
> 전통적인 CNN은 분류에는 탁월하지만, 픽셀 단위의 정밀한 영역분할에는 한계가 있었음
> 이를 극복하기 위해 등장한 것이 FCN
> Dense 층을 사용하지 않고 모든 층을 Conv 층으로 구성하여 공간 정보를 보존

FCN 구조

> Dense 층의 제거는 네트워크가 전체 이미지를 통합적으로 분석하는 대신, 각각의 픽셀을 독립적으로 평가하고, 각 픽셀이 속한 클래스를 예측할 수 있게 함
> FCN에서는 업샘플링을 통해 감소된 이미지 해상도를 다시 원본 사이즈로 복원함
> 그 뒤 Skip 연결을 사용해 네트워크의 초기층에서 얻은 정보를 후반부의 업샘플링 층과 결합함

[업샘플링 방법론들]
1. 최근접 이웃 (nearest neighbor)
2. 선형 보간 (linear interpolation)
3. 전치 합성곱 (transposed convolution)

히트맵: 각 픽셀에 대한 확률을 나타내는 맵


텐서플로를 활용한 FCN 실습

""" 📌 라이브러리 설정 """
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import numpy as np
keras.utils.set_random_seed(27)
tf.random.set_seed(27) 

AUTOTUNE = tf.data.AUTOTUNE   

""" 📌 하이퍼파라미터 설정 """
NUM_CLASSES = 4
INPUT_HEIGHT = 224
INPUT_WIDTH = 224
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
EPOCHS = 20
BATCH_SIZE = 32
MIXED_PRECISION = True
SHUFFLE = True
if MIXED_PRECISION:
    policy = keras.mixed_precision.Policy("mixed_float16")
    keras.mixed_precision.set_global_policy(policy)

""" 📌 데이터셋 로드 """
(train_ds, valid_ds, test_ds) = tfds.load(  
    "oxford_iiit_pet",                           
    split=["train[:85%]", "train[85%:]", "test"], 
    batch_size=BATCH_SIZE, 
    shuffle_files=SHUFFLE,  
)

""" 📌 분할, 사이즈 조정 함수 선언 """ 
def unpack_resize_data(section):
    image = section["image"]
    segmentation_mask = section["segmentation_mask"]

    resize_layer = keras.layers.Resizing(INPUT_HEIGHT, INPUT_WIDTH)

    image = resize_layer(image)
    segmentation_mask = resize_layer(segmentation_mask)

    return image, segmentation_mask

train_ds = train_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)
valid_ds = valid_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE) 
test_ds = test_ds.map(unpack_resize_data, num_parallel_calls=AUTOTUNE)



""" 📌 테스트 세트에서 랜덤하게 이미지 추출, 시각화 """
images, masks = next(iter(test_ds))
random_idx = tf.random.uniform([], minval=0, maxval=BATCH_SIZE, dtype=tf.int32)

test_image = images[random_idx].numpy().astype("float")
test_mask = masks[random_idx].numpy().astype("float")

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

ax[0].set_title("Image")
ax[0].imshow(test_image / 255.0)

ax[1].set_title("Image with segmentation mask overlay")
ax[1].imshow(test_image / 255.0)
ax[1].imshow(test_mask,cmap="inferno",alpha=0.6,)
plt.show()

""" 📌 데이터 전처리, 로드 효율성 높임 """
def preprocess_data(image, segmentation_mask): # ①
    image = keras.applications.vgg19.preprocess_input(image)
    return image, segmentation_mask

train_ds = (
    train_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE).shuffle(buffer_size=1024).prefetch(buffer_size=1024)) # ②
valid_ds = (
    valid_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE).shuffle(buffer_size=1024).prefetch(buffer_size=1024))
test_ds = (
    test_ds.map(preprocess_data, num_parallel_calls=AUTOTUNE).shuffle(buffer_size=1024).prefetch(buffer_size=1024)
)

import tensorflow as tf
from tensorflow.keras import mixed_precision

# 이 한 줄로 mixed precision 정책을 float32 전용으로 바꿔버립니다.
mixed_precision.set_global_policy('float32')

""" 📌 모델 작성"""
input_layer = keras.Input(shape=(INPUT_HEIGHT, INPUT_WIDTH, 3))

vgg_model = keras.applications.vgg19.VGG19(include_top=True, weights="imagenet")

fcn_backbone = keras.models.Model(
    inputs=vgg_model.layers[1].input,
    outputs=[
        vgg_model.get_layer(block_name).output
        for block_name in ["block3_pool", "block4_pool", "block5_pool"]
    ],
)

# 모델 미세 조정을 위해 백본 네트워크의 가중치 값 고정
fcn_backbone.trainable = False
x = fcn_backbone(input_layer)

# conv층, 드롭아웃층 추가
units = [4096, 4096] 
dense_convs = []

for filter_idx in range(len(units)):
    dense_conv = keras.layers.Conv2D(
        filters=units[filter_idx], kernel_size=(7, 7) if filter_idx == 0 else (1, 
1),strides=(1, 1), activation="relu", padding="same", use_bias=False, kernel_initializer=tf.constant_initializer(1.0),
    )
    dense_convs.append(dense_conv)
    dropout_layer = keras.layers.Dropout(0.5)
    dense_convs.append(dropout_layer)

dense_convs = keras.Sequential(dense_convs) 
dense_convs.trainable = False 

x[-1] = dense_convs(x[-1])
pool3_output, pool4_output, pool5_output = x

# FCN 모델의 뒷부분은 fcn32s 아키텍처를 구성 (conv층과 업샘플링층을 정의, 연결)
pool5 = keras.layers.Conv2D(filters=NUM_CLASSES, kernel_size=(1, 1), padding="same", strides=(1, 1), activation="relu",)

fcn32s_conv_layer = keras.layers.Conv2D(filters=NUM_CLASSES, kernel_size=(1, 1), activation="softmax", padding="same", strides=(1, 1),)

# 최종 예측 맵을 원본 이미지 사이즈로 복원하는 업샘플링 과정 (32배 업샘플링)
fcn32s_upsampling = keras.layers.UpSampling2D(size=(32, 32),data_format=keras.backend.image_data_format(), interpolation="bilinear",)

# fcn32s의 최종 출력 생성
final_fcn32s_pool = pool5(pool5_output)
final_fcn32s_output = fcn32s_conv_layer(final_fcn32s_pool)
final_fcn32s_output = fcn32s_upsampling(final_fcn32s_output)
fcn32s_model = keras.Model(inputs=input_layer, outputs=final_fcn32s_output)


# 추가로 FCN의 두가지 변형인 FCN-16s, FCN-8s를 구현, 결합
# FCN-16은 더 세밀한 특성을, FCN-8s는 더욱 정밀한 영역 분할 제공
pool4 = keras.layers.Conv2D(filters=NUM_CLASSES,kernel_size=(1, 1),
padding="same",strides=(1, 1),activation="linear",kernel_initializer=keras.
initializers.Zeros(),)(pool4_output)

pool5 = keras.layers.UpSampling2D(size=(2, 2),data_format=keras.backend.image_data_format(),interpolation="bilinear",)(final_fcn32s_pool)

fcn16s_conv_layer = keras.layers.Conv2D(filters=NUM_CLASSES,kernel_size=(1, 1),
activation="softmax",padding="same",strides=(1, 1),)

fcn16s_upsample_layer = keras.layers.UpSampling2D(size=(16, 16),data_format=keras.
backend.image_data_format(),interpolation="bilinear",)

final_fcn16s_pool = keras.layers.Add()([pool4, pool5])
final_fcn16s_output = fcn16s_conv_layer(final_fcn16s_pool)
final_fcn16s_output = fcn16s_upsample_layer(final_fcn16s_output)

fcn16s_model = keras.models.Model(inputs=input_layer, outputs=final_fcn16s_output)

pool3 = keras.layers.Conv2D( filters=NUM_CLASSES, kernel_size=(1, 1), padding="same", strides=(1, 1), activation="linear", kernel_initializer=keras.initializers.Zeros(),)(pool3_output)

intermediate_pool_output = keras.layers.UpSampling2D(size=(2, 2),data_format=keras.
backend.image_data_format(),interpolation="bilinear",)(final_fcn16s_pool)

fcn8s_conv_layer = keras.layers.Conv2D(filters=NUM_CLASSES,kernel_size=(1, 1),
activation="softmax",padding="same",strides=(1, 1),)

fcn8s_upsample_layer = keras.layers.UpSampling2D(size=(8, 8),data_format=keras.
backend.image_data_format(),interpolation="bilinear",)

final_fcn8s_pool = keras.layers.Add()([pool3, intermediate_pool_output])
final_fcn8s_output = fcn8s_conv_layer(final_fcn8s_pool)
final_fcn8s_output = fcn8s_upsample_layer(final_fcn8s_output)

fcn8s_model = keras.models.Model(inputs=input_layer, outputs=final_fcn8s_output)

# VGG 모델의 마지막 두 층 가중치 추출 및 가중치 재구성, conv층에 적용
weights1 = vgg_model.get_layer("fc1").get_weights()[0]
weights2 = vgg_model.get_layer("fc2").get_weights()[0]

weights1 = weights1.reshape(7, 7, 512, 4096)
weights2 = weights2.reshape(1, 1, 4096, 4096)

dense_convs.layers[0].set_weights([weights1])
dense_convs.layers[2].set_weights([weights2])

# 앞서 설정한 하이퍼파라미터와, 다양한 버전의 FCN 모델 학습
fcn32s_optimizer = keras.optimizers.AdamW(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

fcn32s_loss = keras.losses.SparseCategoricalCrossentropy()

fcn32s_model.compile(
    optimizer=fcn32s_optimizer,loss=fcn32s_loss,metrics=[keras.metrics.
MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),keras.metrics.
SparseCategoricalAccuracy(),],)
fcn32s_history = fcn32s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)

fcn16s_optimizer = keras.optimizers.AdamW(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

fcn16s_loss = keras.losses.SparseCategoricalCrossentropy()
fcn16s_model.compile(optimizer=fcn16s_optimizer,loss=fcn16s_loss,metrics=[keras.
metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),keras.metrics.
SparseCategoricalAccuracy(),],)

fcn16s_history = fcn16s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)

fcn8s_optimizer = keras.optimizers.AdamW(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

fcn8s_loss = keras.losses.SparseCategoricalCrossentropy()
fcn8s_model.compile(optimizer=fcn8s_optimizer,loss=fcn8s_loss,metrics=[keras.
metrics.MeanIoU(num_classes=NUM_CLASSES, sparse_y_pred=False),keras.metrics.
SparseCategoricalAccuracy(),],)

fcn8s_history = fcn8s_model.fit(train_ds, epochs=EPOCHS, validation_data=valid_ds)
>>> Epoch 1/20
	98/98 ━━━━━━━━━━━━━━━━━━━━ 106s 807ms/step - loss: 1.1481 - mean_io_u_1: 0.1871 - sparse_categorical_accuracy: 0.4662 - val_loss: 0.9564 - val_mean_io_u_1: 0.2737 - val_sparse_categorical_accuracy: 0.5596
	Epoch 2/20
	98/98 ━━━━━━━━━━━━━━━━━━━━ 55s 563ms/step - loss: 0.8225 - mean_io_u_1: 0.3476 - sparse_categorical_accuracy: 0.6662 - val_loss: 0.7170 - val_mean_io_u_1: 0.4347 - val_sparse_categorical_accuracy: 0.7233
    ...(생략)...
""" 📌 학습 완료된 모델 시각화 """
images, masks = next(iter(test_ds))
random_idx = tf.random.uniform([], minval=0, maxval=BATCH_SIZE, dtype=tf.int32)

test_image = images[random_idx].numpy().astype("float")
test_mask = masks[random_idx].numpy().astype("float")

pred_image = tf.expand_dims(test_image, axis=0)
pred_image = keras.applications.vgg19.preprocess_input(pred_image)

pred_mask_32s = fcn32s_model.predict(pred_image, verbose=0).astype("float")
pred_mask_32s = np.argmax(pred_mask_32s, axis=-1)
pred_mask_32s = pred_mask_32s[0, ...]

pred_mask_16s = fcn16s_model.predict(pred_image, verbose=0).astype("float")
pred_mask_16s = np.argmax(pred_mask_16s, axis=-1)
pred_mask_16s = pred_mask_16s[0, ...]

pred_mask_8s = fcn8s_model.predict(pred_image, verbose=0).astype("float")
pred_mask_8s = np.argmax(pred_mask_8s, axis=-1)
pred_mask_8s = pred_mask_8s[0, ...]

fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(15, 8))

fig.delaxes(ax[0, 2])

ax[0, 0].set_title("Image")
ax[0, 0].imshow(test_image / 255.0)

ax[0, 1].set_title("Image with ground truth overlay")
ax[0, 1].imshow(test_image / 255.0)
ax[0, 1].imshow( test_mask,cmap="inferno",alpha=0.6,)

ax[1, 0].set_title("Image with FCN-32S mask overlay")
ax[1, 0].imshow(test_image / 255.0)
ax[1, 0].imshow(pred_mask_32s, cmap="inferno", alpha=0.6)

ax[1, 1].set_title("Image with FCN-16S mask overlay")
ax[1, 1].imshow(test_image / 255.0)
ax[1, 1].imshow(pred_mask_16s, cmap="inferno", alpha=0.6)

ax[1, 2].set_title("Image with FCN-8S mask overlay")
ax[1, 2].imshow(test_image / 255.0)
ax[1, 2].imshow(pred_mask_8s, cmap="inferno", alpha=0.6)

plt.show()

>> FCN-32와 비교했을때, FCN-8S가 비교적 더 정확한 결과를 보여줌
>> 영역 분할 과정에서 더 많은 중간 층의 정보를 통합하여 사용하기 때문


5.3.2 U-Net

> 의료 영상 처리와 같은 고해상도 이미지 영역 분할 작업에 특히 뛰어난 성능을 보이는 아키텍처
> 구조가 U 형태를 닮았는데, "수축 경로"와 "확장 경로"의 두 가지 구성 요소를 가짐


수축 경로

> 네트워크의 초반에 위치하며, 전통적인 CNN 구조와 유사함
> 이미지로부터 중요한 특징을 추출하는 것이 목적
> 스테이지를 나누어 네트워크가 점점 더 깊은 특징을 학습할 수 있도록 설계됨

  1. 초기 단계
    - 이미지의 가장 기본적인 특징 감지 (간단한 패턴, 에지, 색상 변화)
  2. 중간 단계
    - 좀 더 복잡한 요소 감지 (텍스처, 패턴의 조합, 일부 형태나 구조)
  3. 깊은 단계
    - 고수준 특징 처리 (객체의 큰 형태, 전체적인 구조)

확장 경로

> 네트워크의 후반에 위치하며, 수축 경로에서 얻은 특징 맵을 다시 원래 사이즈로 복원하는 역할
> 업샘플링 층, 합성곱 층로 구성되고, 스킵 연결이 중요한 역할을 함

  • 업샘플링 층
    - 이미지의 차원을 점차적으로 확대해, 수축 경로에서의 다운샘플링 과정으로 인해 손실된 상세 정보를 복구함
    - 업샘플링은 보통 transpose convolution이나 최근접 이웃 방법을 사용해 수행됨
  • 합성곱 층
    - 업샘플링된 특징맵이 이후에 거치는 레이어.
    - 업샘플링으로 확대된 이미지에서 부드러운 특징 맵을 생성하고, 세부적인 부분을 세밀하게 다듬는 역할
  • 스킵 연결
    - 수축 경로의 각 스테이지에서 얻은 특징 맵을 확장 경로의 해당 스테이지와 결합함
    - 확장 경로에서 생성된 특징 맵에, 수축 경로에서 추출된 상세한 위치 정보를 추가해주어 모델이 이미지의 구조를 더 잘 이해할 수 있도록 함

  1. 초기 단계
    - 수축 겅로의 마지막 합성곱 층의 출력을 받아 업샘플링을 시작함.
    - 업샘플링된 특징 맵은 수축 경로의 마지막 합성곱 층의 출력과 스킵 연결을 통해 결합됨
  2. 중간 단계
    - 업샘플링과 스킵 연결을 계속해서 이어감. 각 스테이지에 대응되는 출력과 결합, 스킵연결 됨
  3. 최종 단계
    - 최종 업샘플링과 결합을 이루어 이미지를 원래의 해상도로 복원하며, 수축 경로의 첫번째 스테이지에서 얻은 특징 맵과 결합됨

텐서플로를 활용한 U-Net 실습

""" 📌 라이브러리 준비 """
!pip install git+https://github.com/tensorflow/examples.git

import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tensorflow_examples.models.pix2pix import pix2pix
from IPython.display import clear_output

""" 📌 데이터셋 준비 """
dataset, info = tfds.load('oxford_iiit_pet', with_info=True)

def normalize(input_image, input_mask):
    input_image = tf.cast(input_image, tf.float32) / 255.0
    input_mask -= 1
    return input_image, input_mask

def load_image(datapoint):
    input_image = tf.image.resize(datapoint['image'], (128, 128))
    input_mask = tf.image.resize(
        datapoint['segmentation_mask'],
        (128, 128),
        method = tf.image.ResizeMethod.NEAREST_NEIGHBOR,)
    input_image, input_mask = normalize(input_image, input_mask)

    return input_image, input_mask

""" 📌 데이터 로드 """
TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE

train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

""" 📌 데이터 증강 """
class Augment(tf.keras.layers.Layer):
    def __init__(self, seed=42):
        super().__init__()
        self.augment_inputs = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
        self.augment_labels = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
  
    def call(self, inputs, labels):
        inputs = self.augment_inputs(inputs)
        labels = self.augment_labels(labels)
        return inputs, labels

train_batches = (
    train_images
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .map(Augment())
    .prefetch(buffer_size=tf.data.AUTOTUNE))

test_batches = test_images.batch(BATCH_SIZE)

""" 📌 학습 데이터 시각화 """
def display(display_list):
    plt.figure(figsize=(15, 15))

    title = ['Input Image', 'True Mask', 'Predicted Mask']

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])
        plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
        plt.axis('off')
    plt.show()

for images, masks in train_batches.take(2):
    sample_image, sample_mask = images[0], masks[0]
    display([sample_image, sample_mask])

""" 📌 모델 정의 """
""" 📌 백본 네트워크 호출 (MobileNetV2) """
base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False) 

layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

down_stack.trainable = False

""" 📌 확장경로 (U-Net 모델 완성) """
up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
] 

def UNET_model(output_channels:int):
    inputs = tf.keras.layers.Input(shape=[128, 128, 3])

    skips = down_stack(inputs)
    x = skips[-1] # ③
    skips = reversed(skips[:-1])

    for up, skip in zip(up_stack, skips):
        x = up(x)
        concat = tf.keras.layers.Concatenate()
        x = concat([x, skip])

    last = tf.keras.layers.Conv2DTranspose(
        filters=output_channels, kernel_size=3, strides=2, padding='same') # ④  
        # 64x64 -> 128x128
    x = last(x)

    return tf.keras.Model(inputs=inputs, outputs=x)

""" 📌 컴파일, 모델 시각화 """
OUTPUT_CLASSES = 3

model = UNET_model(output_channels=OUTPUT_CLASSES)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
tf.keras.utils.plot_model(model, show_shapes=True)

""" 📌 모델 예측 결과 시각화 """
def create_mask(pred_mask):
    pred_mask = tf.math.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0]

def show_predictions(dataset=None, num=1):
    if dataset:
        for image, mask in dataset.take(num):
            pred_mask = model.predict(image)
            display([image[0], mask[0], create_mask(pred_mask)])
    else:
        display([sample_image, sample_mask,
            create_mask(model.predict(sample_image[tf.newaxis, ...]))])
        
class DisplayCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        clear_output(wait=True)
        show_predictions()
        print ('\nSample Prediction after epoch {}\n'.format(epoch+1))

EPOCHS = 20
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_batches, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_batches,
                          callbacks=[DisplayCallback()])

profile
Shoot for the moon! 🔥

0개의 댓글