Tensorflow 기반 모델링을 수행할 때 다음과 같은 API를 사용할 수 있음
Sequential API: 단순하면서 직선적인 모델을 설계할 때 유용
Functional API: 복잡한 네트워크 구조를 가진 모델을 설계할 때 유용
Subclassing API: 객체지향적 형태로 모델을 설계하고자 할 때 유용
각 API는 장단점을 보유하고 있어, 용도에 맞게 활용이 필요함
Sequential() 클래스 내 필요한 레이어를 적재하는 방식으로 사용from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Flatten
# Sequential() 클래스 내 레이어을 순차적으로 적재
model = Sequential([
Flatten(input_shape=(28, 28)), # 입력 레이어
Dense(128, activation='relu'), # 은닉층
Dense(10, activation='softmax') # 출력층
])
Model() 클래스에 시작점과 종료점을 넣어줘야 함from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Dense, Flatten
inputs = Input(shape=(28, 28))
flat = Flatten()(inputs) # inputs 레이어를 flat 레이어에 연결
dense = Dense(128, activation='relu')(flat) # flat 레이어를 dense 레이어에 연결
outputs = Dense(10, activation='softmax')(dense) # dense 레이어를 outputs 레이어에 연결
# 시작점/종료점 명시하여 모델 생성
model = Model(inputs=inputs, outputs=outputs)
tf.keras.Layer 또는 tf.keras.Model 클래스를 상속받아 Subclassing 방식으로 모델을 제작할 수 있음import tensorflow as tf
# 커스텀 클래스 정의
class MyModel(tf.keras.Model): # Model 클래스 상속
def __init__(self):
super(MyModel, self).__init__()
# 사용할 레이어 정의
self.conv1 = tf.keras.layers.Conv2D(32, 3, activation='relu')
self.flatten = tf.keras.layers.Flatten()
self.dense1 = tf.keras.layers.Dense(128, activation='relu')
self.dense2 = tf.keras.layers.Dense(10)
# 레이어 연결
def call(self, inputs):
x = self.conv1(inputs)
x = self.flatten(x)
x = self.dense1(x)
return self.dense2(x)
# 모델 생성
model1 = MyModel()
model2 = MyModel() # 동일한 모델을 손쉽게 여러개 찍어낼 수 있음
# 모델 컴파일
model1.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 모델 훈련
model1.fit(train_images, train_labels, epochs=10)
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Input
# Functional API로 입력 생성
inputs = Input(shape=(32,))
x = Dense(64, activation='relu')(inputs)
# Sequential API로 레이어 생성
seq_model = Sequential([
Dense(64, activation='relu'),
Dense(32, activation='relu')
])
# Functional API와 Sequential API 간 연결
x = seq_model(x)
outputs = Dense(10, activation='softmax')(x)
# 모델 생성
model = Model(inputs=inputs, outputs=outputs)
*이 글은 제로베이스 데이터 취업 스쿨의 강의 자료 일부를 발췌하여 작성되었습니다.