Tensorflow tf.Variables error

jswboseok·2023년 2월 1일
0

문제

모델 A의 loss function에 다른 모델 B의 output을 피드백삼아 loss function을 구현할 시 모델 A를 학습하는 과정에서 지속적으로 모델 B를 생성할 때 ValueError가 발생하였다.

본인은 다음과 같이 VGG19 모델의 output을 이용하여 loss function을 구현하였다.

class VGGLOSS(object):
    def __init__(self, input_shape):
        self.input_shape = input_shape
        
    # computes VGG loss or content loss
    def vgg_loss(self, y_true, y_pred):
    	vgg19 = VGG19(include_top=False, weights='imagenet', input_shape=self.input_shape)
        vgg19.trainable = False
        # Make trainable as False
        for l in vgg19.layers:
            l.trainable = False
        model = Model(inputs=vgg19.input, outputs=vgg19.get_layer('block5_conv4').output)
        model.trainable = False
        
        return K.mean(K.square(model(y_true) - model(y_pred)))

이를 model.compile에 넘기면 다음과 같이 넘기게 된다.

vggloss = VGGLOSS(input_shape)
model.compile(loss=vggloss.vgg_loss, optimizer=Adam(lr=learning_rate) )

이후에 학습을 실행하게 되면 다음과 같은 에러가 발생하였다.

Error log

ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.

해결책

Error log에 주어진 링크에서 힌트를 얻을 수 있을까 하여 링크로 가보았다.
tf.Variables 만들기 섹션으로 가보았다.

여러 Keras 모델과 함께 사용
동일한 Function에 다른 모델 인스턴스를 전달할 때에도 ValueError: tf.function only supports singleton tf.Variables created on the first call.이 발생할 수 있습니다.

이 오류는 Keras 모델(입력 형상이 정의되지 않음)과 Keras 레이어가 처음 호출될 때 tf.Variables를 만들기 때문에 발생합니다. 이미 호출된 Function 내에서 이러한 변수를 초기화하려고 할 수도 있습니다. 이 오류를 방지하려면 model.build(input_shape)를 호출하여 모델을 훈련하기 전에 모든 가중치를 초기화합니다.

모델이 훈련하기 전에 모든 가중치를 초기화해야한다는 문구가 있다. 위 사이트뿐만 아니라 다른 블로그나 stackoverflow 글들을 참고하면서 그 이유로 모델 학습 과정에서 다른 모델을 생성할 때 나타나는 문제라고 판단하였다. 따라서 loss function 코드를 다음과 같이 바꾸었다.

class VGGLOSS(object):
    def __init__(self, input_shape):
        self.input_shape = input_shape
        vgg19 = VGG19(include_top=False, weights='imagenet', input_shape=self.input_shape)
        vgg19.trainable = False
        # Make trainable as False
        for l in vgg19.layers:
            l.trainable = False
        self.model = Model(inputs=vgg19.input, outputs=vgg19.get_layer('block5_conv4').output)
        self.model.trainable = False
        
    # computes VGG loss or content loss
    def vgg_loss(self, y_true, y_pred):
        return K.mean(K.square(self.model(y_true) - self.model(y_pred)))

이렇게 되면 VGG19 모델은 초기 객체가 생성될때만 생성되고 학습 과정에서는 객체에 정의된 모델을 사용하기 때문에 오류가 나지 않았다.

참고

profile
냠냠

0개의 댓글