저번 글에서는 PyTorch로 모델을 불러와 이미지를 예측 했다.
VGG16의 모델이 가장 좋은 예측결과를 냈으며
특정 이미지에서는 다른 모델이 VGG 모델보다 좋은 결과를 내는 경우도 있었다.
Keras에서도 어떻게 작동을 하는지 확인하자.
Keras의 applications 모듈을 사용하여 사전 학습된 아래 모델들을 로드
1. VGG16
2. ResNet
3. Inception v3
4. MobileNet v2
5. DenseNet201
6. Mobile NASNet
7. EfficientNetB7
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.vgg16 import preprocess_input, decode_predictions
import numpy as np
model = VGG16(weights='imagenet')
model.summary()
!wget https://moderncomputervision.s3.eu-west-2.amazonaws.com/imagesDLCV.zip
!unzip imagesDLCV.zip
!rm rf images/class1/.DS_Store
import cv2
from os import listdir
from os.path import isfile, join
# Get images located in ./images folder
mypath = "./images/class1/"
file_names = [f for f in listdir(mypath) if isfile(join(mypath, f))]
file_names
import matplotlib.pyplot as plt
fig=plt.figure(figsize=(16,16))
# Loop through images run them through our classifer
for (i,file) in enumerate(file_names):
img = image.load_img(mypath+file, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
#load image using opencv
img2 = cv2.imread(mypath+file)
#imageL = cv2.resize(img2, None, fx=.5, fy=.5, interpolation = cv2.INTER_CUBIC)
# Get Predictions
preds = model.predict(x)
preditions = decode_predictions(preds, top=3)[0]
print(preditions)
# Plot image
sub = fig.add_subplot(len(file_names),1, i+1)
sub.set_title(f'Predicted {str(preditions[0][1])}')
plt.axis('off')
plt.imshow(cv2.cvtColor(img2, cv2.COLOR_BGR2RGB))
plt.show()
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as np
model = ResNet50(weights='imagenet')
model.summary()
모델을 통해 예측하는 코드는 1번의 VGG16 코드와 동일하므로 생략
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.inception_v3 import preprocess_input
import numpy as np
model = InceptionV3(weights='imagenet')
model.summary()
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
import numpy as np
model = MobileNetV2(weights='imagenet')
model.summary()
from tensorflow.keras.applications.densenet import DenseNet201
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.densenet import preprocess_input
import numpy as np
model = DenseNet201(weights='imagenet')
model.summary()
from tensorflow.keras.applications.nasnet import NASNetMobile
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.nasnet import preprocess_input
import numpy as np
model = NASNetMobile(weights='imagenet')
model.summary()
from tensorflow.keras.applications.efficientnet import EfficientNetB7
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.efficientnet import preprocess_input
import numpy as np
model = EfficientNetB7(weights='imagenet')
model.summary()
Keras 모델이 대체로 PyTorch 모델보다 더 나은 성능을 보여준다.
EfficientNet B7 모델이 모든 테스트 이미지를 정확하게 예측했다.
VGG16, ResNet50 등 다른 모델들도 대체로 좋은 성능을 보였지만, 일부 이미지는 예측하지 못했다.
Limousine - 0.8712
Basketball - 0.9486
Collie - 0.7440
German Shepherd - 0.9481
Christmas Stocking - 0.9472
Doormat - 0.9480
Burrito - 0.9551
Spider Web - 0.8747
Beer Glass - 0.0827 (실패 예측)
Limousine - 0.8414
Basketball - 0.8935
Collie - 0.7020
German Shepherd - 0.7762
Christmas Stocking - 0.8364
Doormat - 0.7075
Burrito - 0.8121
Spider Web - 0.7373
Beer Glass - 0.7442
Mobile NASNet은 대부분의 이미지에 대해 더 높은 확률로 정확한 예측을 제공했고,
EfficientNetB7은 일부 이미지에서 더 낮은 확률로 예측했지만, 특정 이미지(예: Beer Glass)에서 더 높은 확률로 예측했다.
그렇다면 NASNet의 모델 성능이 더 좋지 않은가?
하지만 결론에는 EfficientNetB7가 가장 좋다고 명시했다.
EfficientNetB7이 가장 좋은 모델이라고 설명된 이유는 더
많은 이미지에서 전반적으로 높은 성능을 보였기 때문이다.
Do you have the above dataset link?
The above code for fetching the dataset is not working so if you can provide a drive a link to access data