[딥러닝] 전이학습 (Transfer Learning)

Seung Joo·2021년 8월 24일
0

전이학습(Transfer Learning)

기존에 원래의 프로젝트의 목적과 다른 데이터로 학습된 네트워크를 재사용 가능하도록 하는 라이브러리이다. 이를 통해 수천 시간의 GPU로 학습된 모델을 다운받아 나의 프로젝트에 활용할 수 있다. 학습되었다는 것은 가중치(Weight)와 편향(bias)을 포함하여 학습된 모델의 일부를 재사용하기에 전이 학습이라고 표현한다.
라이브러리에서 모델을 가져와 새로운 층을 쌓아올려서 사용할수도 있다.

전이학습(resnet)을 사용한 Moutain,Forest dataset 분류

import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory
import matplotlib.pyplot as plt
from zipfile import ZipFile
import os

with ZipFile('mountainForest.zip', 'r') as zip:
    data = zip.extractall(os.getcwd())
    
TRAIN_PATH = '/content/mountainForest/train'
VAL_PATH = '/content/mountainForest/validation'

train = image_dataset_from_directory(TRAIN_PATH, label_mode='int', class_names=['forest', 'mountain'])
val = image_dataset_from_directory(VAL_PATH, label_mode='int', class_names=['forest', 'mountain'])
train : Found 533 files belonging to 2 classes.
val : Found 195 files belonging to 2 classes.
class_names = train.class_names

plt.figure(figsize=(10, 10))
for images, labels in train.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

import numpy as np
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model

resnet = ResNet50(weights='imagenet', include_top=False)

# 학습 동결하여 가중치 업데이트가 일어나지 않도록 함.
for layer in resnet.layers:
    layer.trainable = False

# 모델 정의
x = resnet.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(1, activation='sigmoid')(x) # 출력층을 설계합니다.
model = Model(resnet.input, predictions)

# 모델 컴파일
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

model.fit(train, batch_size=128, epochs=5)
Epoch 1/5
17/17 [==============================] - 4s 84ms/step - loss: 0.1711 - accuracy: 0.9306
Epoch 2/5
17/17 [==============================] - 2s 84ms/step - loss: 0.0497 - accuracy: 0.9850
Epoch 3/5
17/17 [==============================] - 2s 81ms/step - loss: 0.0163 - accuracy: 0.9962
Epoch 4/5
17/17 [==============================] - 2s 84ms/step - loss: 0.0037 - accuracy: 0.9981
Epoch 5/5
17/17 [==============================] - 2s 84ms/step - loss: 3.0132e-04 - accuracy: 1.0000
model.evaluate(val)
7/7 [==============================] - 1s 70ms/step - loss: 0.0627 - accuracy: 0.9846
profile
조금씩 천천히

0개의 댓글