이 글은 머신러닝을 1도 모르는 한 불쌍한 대학생이 학부에서 텀 프로젝트를 해내기 위해 논문 코드를 하나하나 뜯어보는 글이다. 배우면서 쓰는 글이다 보니 정확성이 많이 떨어질 수 있음을 양해 바란다.
주제는 crnn이라는 모델을 활용한 ocr 구현이다.
출처
import os
import time
import numpy as np
import tensorflow as tf
from scipy.misc import imread, imresize, imsave
from tensorflow.contrib import rnn
from data_manager import DataManager
from utils import (
sparse_tuple_from,
resize_image,
label_to_array,
ground_truth_to_word,
levenshtein,
)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
class CRNN(object):
def __init__(
self,
batch_size, # 배치 사이즈 지정
model_path, # 파라미터와 체크포인트 파일 저장, 로드할 디렉토리 경로 지정
examples_path, # 학습 데이터가 있는 디렉토리 경로 지정
max_image_width, #입력 이미지 최대 너비 지정
train_test_ratio, # 학습, 테스트 데이터 나누는 비율 설정
restore, # 모델을 이전에 학습한 상태로 복원할지 여부 결정 플래그
char_set_string, # 모델이 인식해야 할 문자 집합 지정
use_trdg, #True 또는 False 값을 가지며, 텍스트 데이터 증강 기술인 TextRecognitionDataGenerator(TRDG)를 사용할지 여부를 결정
language, # 모델에서 사용할 언어를 지정
):
self.step = 0
self.CHAR_VECTOR = char_set_string # CHAR_VECTOR: 모델에서 인식해야 하는 문자 집합을 나타내는 문자
self.NUM_CLASSES = len(self.CHAR_VECTOR) + 1 # NUM_CLASSES: 문자 집합에 포함된 문자 수에 1을 더한 값. 모델의 출력 클래스 수를 나타냄
print("CHAR_VECTOR {}".format(self.CHAR_VECTOR))
print("NUM_CLASSES {}".format(self.NUM_CLASSES))
self.model_path = model_path #model_path: 모델 파일과 체크포인트 파일이 저장되거나 로드될 위치 지정
self.save_path = os.path.join(model_path, "ckp") # 체크포인트 파일이 저장될 위치 지정.
# os.path.join(): 파일 시스템의 경로를 조인하여 새로운 경로를 만들어내는 함수
self.restore = restore
self.training_name = str(int(time.time())) # 모델 이름 지어주기
self.session = tf.Session() # tf의 session 함수 인스턴스 생성
# Building graph
with self.session.as_default():
(
self.inputs, # 입력 데이터
self.targets, # 타겟 데이터
self.seq_len,
self.logits, # 출력
self.decoded, # 디코딩된 결과
self.optimizer, # 최적화 방법
self.acc, # 정확도
self.cost, # 손실
self.max_char_count, # 최대 문자 수
self.init, # 초기화 연산
) = self.crnn(max_image_width)
self.init.run()
with self.session.as_default():
self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
# Loading last save if needed
if self.restore:
print("Restoring")
ckpt = tf.train.latest_checkpoint(self.model_path)
if ckpt:
print("Checkpoint is valid")
self.step = int(ckpt.split("-")[1])
self.saver.restore(self.session, ckpt)
# Creating data_manager
self.data_manager = DataManager(
batch_size,
model_path,
examples_path,
max_image_width,
train_test_ratio,
self.max_char_count,
self.CHAR_VECTOR,
use_trdg,
language,
)
다음 코드를 하나하나 뜯어보자.
print("CHAR_VECTOR {}".format(self.CHAR_VECTOR))
print("NUM_CLASSES {}".format(self.NUM_CLASSES))
우선 이 프린트 함수의 구조를 살펴보면:
"NUM_CLASSES {}" 부분은 출력 문자열의 형식을 정의한다. {}에는 나중에 들어갈 변수의 위치를 표시해준다.
.format(self.NUM_CLASSES) 부분은 중괄호 {}에 self.NUM_CLASSES 변수의 값을 삽입한다.
이 코드의 출력값은 다음과 같을 것이다
CHAR_VECTOR ABCDEFGHI
NUM_CLASSES 10
NUM_CLASSES는 CHAR_VECTOR의 길이에 1을 더한 값이므로 출력은 10이 된다.
다음 코드를 보자.
self.model_path = model_path
self.save_path = os.path.join(model_path, "ckp")
self.model_path와 self.save_path를 초기화한다.
self.model_path는 모델 파일과 체크포인트 파일을 저장하거나 로드할 디렉토리 경로를 나타낸다.
os.path.join() 함수는 파일 시스템 경로를 조인하고 새 경로를 생성하는 데 사용된다. 여기서는 model_path와 "ckp" 문자열을 조합하여 self.save_path에 새 경로를 할당한다. 이렇게 생성된 self.save_path는 모델의 체크포인트 파일을 저장할 디렉토리 경로를 나타낸다.
self.training_name = str(int(time.time()))
self.session = tf.Session()
이 부분은 self.training_name이라는 클래스 멤버 변수를 초기화한다.
time.time() 함수는 현재 시간을 초 단위로 반환한다. 이 값은 컴퓨터 시스템의 현재 시각에 대한 타임스탬프이다.
int(time.time())는 현재 시각의 타임스탬프를 정수형으로 변환한다.
str(int(time.time()))는 이 정수 타임스탬프를 문자열로 변환한다.
따라서 self.training_name은 현재 시간을 문자열 형태로 나타낸 것으로, 학습 중인 모델을 고유하게 식별하기 위한 이름 또는 식별자로 사용될 수 있다.
self.session = tf.Session()은 TensorFlow에서 세션 객체(tf.Session())를 생성하여 self.session에 할당한다.
텐서플로우 세션에 대하여
TensorFlow의 세션은 그래프를 실행하고 변수를 초기화하거나 학습을 진행하는 데 사용된다.
self.session은 모델을 학습하고 추론하기 위한 TensorFlow 세션을 나타내며, 이 세션을 사용하여 모델의 그래프를 실행하고 데이터를 처리할 수 있다.
요약하면, 이 부분은 현재 시간을 사용하여 학습 중인 모델을 식별하는 이름(self.training_name)을 생성하고, TensorFlow 세션 객체(self.session)를 초기화하는 역할을 한다. 이것들은 모델 식별 및 TensorFlow를 사용한 학습 및 추론에 필요한 초기 설정이다.
with self.session.as_default():
(
self.inputs,
self.targets,
self.seq_len,
self.logits,
self.decoded,
self.optimizer,
self.acc,
self.cost,
self.max_char_count,
self.init,
) = self.crnn(max_image_width)
self.init.run()
with self.session.as_default()
as_default()에 대해서:

요약하자면 as_default는 텐서플로우에서 쓰는 세션 관련 함수다. 현재 세션을 기본 세션으로 설정한다.
따라서 위의 코드는 self.session을 기본 세션으로 지정한다.
with 문에 대해서:

위의 코드에서는 self.crnn(max_image_width)에서 반환된 값들을 각각의 클래스 멤버 변수에 할당한다.
self.crnn(max_image_width)
self.crnn() 함수를 호출하여 CRNN 모델을 구성하고 초기화하는 작업을 수행한다.
max_image_width 매개변수는 모델에서 처리할 이미지의 최대 너비를 지정한다.
crnn 함수를 구현한 코드는 뒤에 가서 다룰 예정이니 이해가 안되더라도 일단 넘어가자.
self.crnn(max_image_width) 함수가 반환한 모델의 입력 데이터, 타겟 데이터, 출력(logits), 디코딩된 결과(decoded), 최적화 방법(optimizer), 정확도(acc), 손실(cost), 최대 문자 수(max_char_count), 그리고 초기화 연산(init)이 클래스 멤버 변수에 저장된다.
self.init.run()
TensorFlow의 모델을 사용하기 전에 필요한 초기화 연산을 실행하는 부분이다.
with self.session.as_default():
self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
# Loading last save if needed
if self.restore:
print("Restoring")
ckpt = tf.train.latest_checkpoint(self.model_path)
if ckpt:
print("Checkpoint is valid")
self.step = int(ckpt.split("-")[1])
self.saver.restore(self.session, ckpt)
self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=10)
self.saver: Tensorflow의 saver 클래스의 인스턴스
tf.train.Saver: 모델의 상태를 저장하는 변수를 저장, 복원하는 데 사용됨
tf.global_variables(): 훈련 및 추론 과정 중 저장하고 복원해야 하는 전역 변수들
max_to_keep: 기억해야 할 checkpoints 수
if self.restore:
self.restore 변수가 True 이면
print("Restoring")
ckpt = tf.train.latest_checkpoint(self.model_path)
이 코드는 지정된 디렉토리인 self.model_path에서 가장 최근의 체크포인트 파일을 찾아 ckpt 변수에 저장한다(체크포인트 파일은 모델의 변수 상태를 저장한 파일이다).
if ckpt:
체크포인트 파일(ckpt)이 존재하는 경우(체크포인트 파일이 존재하지 않을 수 있으므로 if문으로 체크해준다).
print("Checkpoint is valid")
self.step = int(ckpt.split("-")[1])
self.saver.restore(self.session, ckpt)
checkpoint가 존재한다고 출력해 알리고, self.step에 체크포인트에서 추출한 step 번호를 저장한다.
self.saver.restore(self.session, ckpt): TensorFlow Saver 객체를 사용하여 모델의 변수를 지정된 체크포인트 파일(ckpt)에서 복원한다. 모델은 이전 훈련 상태로 돌아가며, 이전에 저장된 변수 값으로 초기화된다.
self.step = int(ckpt.split("-")[1])
개인적으로 이 부분 코드가 이해가 안 돼서 조금 더 알아봤다.
ckpt에 들어 있는 값은 체크포인트 파일의 경로이다(ckpt가 반환값을 받는 함수인 tf.train.latest_checkpoint(self.model_path)은 model_path에서 가장 최근의 체크포인트 파일을 찾아 그 "경로"를 반환환다).
체크포인트 파일은 일반적으로 하이픈과 숫자로 이루어져 있다. skpt.split("-")는 파일 경로에서 하이픈을 제거하고 숫자만을 남긴다.
[1]은 남은 숫자 목록(배열?)에서 두 번째 요소를 의미하고 이 int값을 우리는 step으로 쓰기로 했다.
self.data_manager = DataManager(
...
)
DataManager 클래스는 우리가 뒤에서 뜯어볼 data_manager.py에서 정의된 클래스이다. 우리는 코드 첫부분에서 from data_manager import DataManager를 해 줬었다.
지금까지 crnn모델의 init 함수의 코드를 분석해 보았다. 이제 다음 장에서는 본격적으로 crnn 함수의 코드를 뜯어보도록 하겠다.