[python] CRNN 한글 모델 학습하기(1) - AIhub 데이터셋

apphia·2021년 11월 26일
post-thumbnail
  1. 최상위 폴더에 deep-text-recognition-benchmark 클론해오기
  2. 최상위 폴더에 kor_dataset 폴더 생성 후, 데이터셋들을 해당 폴더로 옮기기
  3. AIhub 데이터셋 학습하기
  4. Finetuning 데이터셋 추가 학습하기
  5. 완성된 pretrained model 테스트하기

데이터셋 수집

최상위폴더
  |--deep-text-recognition-benchmark # 클론해오기
  |--kor_dataset
    |--aihub_data
      |--htr
        |--images # 여기에 AIhub 필기체 이미지파일들 넣기
        |--handwriting_data_info1.json # 라벨링 파일
      |--ocr
        |--images # 여기에 AIhub 인쇄체 이미지파일들 넣기
        |--printed_data_info.json # 라벨링 파일

데이터셋 가공

  • 최상위폴더에 aihub_dataset.py를 만들어준다.
  • 해당 내용은 여기를 참고하였다.
# aihub_dataset.py
import json
import random
import os
from tqdm import tqdm

# htr / ocr
data_type = 'htr' 
# handwriting_data_info1.json / printed_data_info.json
labeling_filename = 'handwriting_data_info1.json'

## Check Json File
file = json.load(open(f'./kor_dataset/aihub_data/{data_type}/{labeling_filename}'))

## Separate dataset - train, validation, test
image_files = os.listdir(f'./kor_dataset/aihub_data/{data_type}/images/') 
total = len(image_files)

random.shuffle(image_files)

n_train = int(len(image_files) * 0.7)
n_validation = int(len(image_files) * 0.15)
n_test = int(len(image_files) * 0.15)

print(n_train, n_validation, n_test)

train_files = image_files[:n_train]
validation_files = image_files[n_train:n_train+n_validation]
test_files = image_files[-n_test:]

## Separate image id - train, validation, test
train_img_ids = {}
validation_img_ids = {}
test_img_ids = {}

for image in file['images']: # {filename}: {image id}
  if image['file_name'] in train_files:
    train_img_ids[image['file_name']] = image['id']
  elif image['file_name'] in validation_files:
    validation_img_ids[image['file_name']] = image['id']
  elif image['file_name'] in test_files:
    test_img_ids[image['file_name']] = image['id']

## Annotations - train, validation, test 
train_annotations = {f:[] for f in train_img_ids.keys()} # {image id}: []
validation_annotations = {f:[] for f in validation_img_ids.keys()}
test_annotations = {f:[] for f in test_img_ids.keys()}

train_ids_img = {train_img_ids[id_]:id_ for id_ in train_img_ids}
validation_ids_img = {validation_img_ids[id_]:id_ for id_ in validation_img_ids}
test_ids_img = {test_img_ids[id_]:id_ for id_ in test_img_ids}

for idx, annotation in enumerate(file['annotations']):
  if idx % 5000 == 0:
    print(idx,'/',len(file['annotations']),'processed')
  if annotation['image_id'] in train_ids_img:
    train_annotations[train_ids_img[annotation['image_id']]].append(annotation)
  elif annotation['image_id'] in validation_ids_img:
    validation_annotations[validation_ids_img[annotation['image_id']]].append(annotation)
  elif annotation['image_id'] in test_ids_img:
    test_annotations[test_ids_img[annotation['image_id']]].append(annotation)

## Write json files
with open(f'{data_type}_train_annotation.json', 'w') as file:
  json.dump(train_annotations, file)
with open(f'{data_type}_validation_annotation.json', 'w') as file:
  json.dump(validation_annotations, file)
with open(f'{data_type}_test_annotation.json', 'w') as file:
  json.dump(test_annotations, file)

## Make gt_xxx.txt files
data_root_path = f'./kor_dataset/aihub_data/{data_type}/images/'
save_root_path = f'./deep-text-recognition-benchmark/{data_type}_data/'

obj_list = ['test', 'train', 'validation']
for obj in obj_list:
  total_annotations = json.load(open(f'./{data_type}_{obj}_annotation.json'))
  gt_file = open(f'{save_root_path}gt_{obj}.txt', 'w')
  for file_name in tqdm(total_annotations):
    annotations = total_annotations[file_name]
    for idx, annotation in enumerate(annotations):
      text = annotation['text']
      gt_file.write(f'{obj}/{file_name}\t{text}')
  • 위 파일을 실행시키면,
python3 aihub_dataset.py
  • 아래와 같은 파일 구조가 만들어진다.
최상위폴더
  |--deep-text-recognition-benchmark # 클론해오기
    |--htr_data
      |--gt_test.txt
      |--gt_train.txt
      |--gt_validation.txt
    |--...
  |--kor_dataset
    |--aihub_data
      |--htr
        |--images # 여기에 AIhub 필기체 이미지파일들 넣기
        |--handwriting_data_info1.json # 라벨링 파일
      |--ocr
        |--images # 여기에 AIhub 인쇄체 이미지파일들 넣기
        |--printed_data_info.json # 라벨링 파일
  |--aihub_dataset.py
  |--...
  • 생성된 gt 파일들을 기반으로 lmdb 데이터를 생성해주어야 한다.
cd deep-text-recognition-benchmark/htr_data
  • 이를 위해 kor_dataset에 있는 이미지 파일들을 가져와서 train, validation, test 폴더에 분리해주도록 한다.
# get_images.py
import shutil

data_root_path = '../../kor_dataset/aihub_data/htr/images/'
save_root_path = './images/'

# copy images from dataset directory to current directory
shutil.copytree(data_root_path, save_root_path)

# separate dataset : train, validation, test
obj_list = ['train', 'test', 'validation']
for obj in obj_list:
  with open(f'gt_{obj}.txt', 'r') as f:
    lines = f.readlines()
    for line in lines:
      file_path = line.split('.png')[0]
      file_name = file_path.split('/')[1] + '.png'
      res = shutil.move(save_root_path+file_name, f'./{obj}/')
  • get_images.py를 실행시키기 전에 반드시 현재 폴더에 train, test, validation 폴더를 만들어주도록 한다.
mkdir train test validation
python3 get_images.py
  • 이제 lmdb 데이터를 만들 수 있다.
cd ../../
python3 ./deep-text-recognition-benchmark/create_lmdb_dataset.py \
    --inputPath ./deep-text-recognition-benchmark/htr_data/ \
    --gtFile ./deep-text-recognition-benchmark/htr_data/gt_train.txt \
    --outputPath ./deep-text-recognition-benchmark/htr_data_lmdb/train

python3 ./deep-text-recognition-benchmark/create_lmdb_dataset.py \
    --inputPath ./deep-text-recognition-benchmark/htr_data/ \
    --gtFile ./deep-text-recognition-benchmark/htr_data/gt_validation.txt \
    --outputPath ./deep-text-recognition-benchmark/htr_data_lmdb/validation

인쇄체 역시 htrocr로 수정하여 위와 같이 진행하면 된다.

학습

  • 학습 순서는 인쇄체 > 필기체 순으로 진행하였다.
  • batch_size나 batch_max_length는 본인 데이터셋에 맞게 적절히 조절해주도록 한다.
  • CUDA_VISIBLE_DIVICES나 workers 역시 본인 gpu 환경에 맞게 적절한 값을 넣어주면 된다.
  • 그 외 기타 옵션들은 deep-text-recognition-benchmark/train.py를 열어서 수정한다.

train.py, test.py 수정하기

  • 여기서는 select_data, batch_ratio, character의 default 값만 아래와 같이 변경해주고, 나머지 argument들은 CUDA 명령어에 옵션으로 넣어주었다.
  • character의 경우, 추가로 인식하게 만들고싶은 글자들을 넣어주면 된다. (현재 ' " : ; 와 같은 몇가지 특수문자들이나 일부 한글은 포함되지 않은 상태이다. 이런 글자들도 추가하고 싶다면 character의 default값에 추가로 넣어주면 된다.)
parser.add_argument('--select_data', type=str, default='/',
                        help='select training data (default is MJ-ST, which means MJ and ST used as training data)')
parser.add_argument('--batch_ratio', type=str, default='1',
                        help='assign ratio for each selected data in the batch')
parser.add_argument('--character', type=str,
            default='0123456789abcdefghijklmnopqrstuvwxyz가각간갇갈감갑값갓강갖같갚갛개객걀걔거걱건걷걸검겁것겉게겨격겪견결겹경곁계고곡곤곧골곰곱곳공과관광괜괴굉교구국군굳굴굵굶굽궁권귀귓규균귤그극근글긁금급긋긍기긴길김깅깊까깍깎깐깔깜깝깡깥깨꺼꺾껌껍껏껑께껴꼬꼭꼴꼼꼽꽂꽃꽉꽤꾸꾼꿀꿈뀌끄끈끊끌끓끔끗끝끼낌나낙낚난날낡남납낫낭낮낯낱낳내냄냇냉냐냥너넉넌널넓넘넣네넥넷녀녁년념녕노녹논놀놈농높놓놔뇌뇨누눈눕뉘뉴늄느늑는늘늙능늦늬니닐님다닥닦단닫달닭닮담답닷당닿대댁댐댓더덕던덜덟덤덥덧덩덮데델도독돈돌돕돗동돼되된두둑둘둠둡둥뒤뒷드득든듣들듬듭듯등디딩딪따딱딴딸땀땅때땜떠떡떤떨떻떼또똑뚜뚫뚱뛰뜨뜩뜯뜰뜻띄라락란람랍랑랗래랜램랫략량러럭런럴럼럽럿렁렇레렉렌려력련렬렵령례로록론롬롭롯료루룩룹룻뤄류륙률륭르른름릇릎리릭린림립릿링마막만많말맑맘맙맛망맞맡맣매맥맨맵맺머먹먼멀멈멋멍멎메멘멩며면멸명몇모목몬몰몸몹못몽묘무묵묶문묻물뭄뭇뭐뭘뭣므미민믿밀밉밌및밑바박밖반받발밝밟밤밥방밭배백뱀뱃뱉버번벌범법벗베벤벨벼벽변별볍병볕보복볶본볼봄봇봉뵈뵙부북분불붉붐붓붕붙뷰브븐블비빌빔빗빚빛빠빡빨빵빼뺏뺨뻐뻔뻗뼈뼉뽑뿌뿐쁘쁨사삭산살삶삼삿상새색샌생샤서석섞선설섬섭섯성세섹센셈셋셔션소속손솔솜솟송솥쇄쇠쇼수숙순숟술숨숫숭숲쉬쉰쉽슈스슨슬슴습슷승시식신싣실싫심십싯싱싶싸싹싼쌀쌍쌓써썩썰썹쎄쏘쏟쑤쓰쓴쓸씀씌씨씩씬씹씻아악안앉않알앓암압앗앙앞애액앨야약얀얄얇양얕얗얘어억언얹얻얼엄업없엇엉엊엌엎에엔엘여역연열엷염엽엿영옆예옛오옥온올옮옳옷옹와완왕왜왠외왼요욕용우욱운울움웃웅워원월웨웬위윗유육율으윽은을음응의이익인일읽잃임입잇있잊잎자작잔잖잘잠잡잣장잦재쟁쟤저적전절젊점접젓정젖제젠젯져조족존졸좀좁종좋좌죄주죽준줄줌줍중쥐즈즉즌즐즘증지직진질짐집짓징짙짚짜짝짧째쨌쩌쩍쩐쩔쩜쪽쫓쭈쭉찌찍찢차착찬찮찰참찻창찾채책챔챙처척천철첩첫청체쳐초촉촌촛총촬최추축춘출춤춥춧충취츠측츰층치칙친칠침칫칭카칸칼캄캐캠커컨컬컴컵컷케켓켜코콘콜콤콩쾌쿄쿠퀴크큰클큼키킬타탁탄탈탑탓탕태택터턱턴털텅테텍텔템토톤톨톱통퇴투툴툼퉁튀튜트특튼튿틀틈티틱팀팅파팎판팔팝패팩팬퍼퍽페펜펴편펼평폐포폭폰표푸푹풀품풍퓨프플픔피픽필핏핑하학한할함합항해핵핸햄햇행향허헌험헤헬혀현혈협형혜호혹혼홀홈홉홍화확환활황회획횟횡효후훈훌훨휘휴흉흐흑흔흘흙흡흥흩희흰히힘?!.,()',
help='character label')

학습 시작하기

# 인쇄체(ocr)
CUDA_VISIBLE_DEVICES=0 python3 ./deep-text-recognition-benchmark/train.py \
    --train_data ./deep-text-recognition-benchmark/ocr_data_lmdb/train \
    --valid_data ./deep-text-recognition-benchmark/ocr_data_lmdb/validation \
    --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC \
    --batch_size 512 --batch_max_length 200 --data_filtering_off --workers 0 \
    --num_iter 100000 --valInterval 100
  • 위 명령어를 실행시키면 아래 그림과 같이 학습이 진행된다.
  • 학습이 완료되면 최상위폴더에 saved_models라는 폴더가 생성되고, 그 아래 pretrained 모델이 저장되어 있다.
최상위폴더
  |--deep-text-recognition-benchmark # 클론해오기
    |--htr_data
      |--train
      |--test
      |--validation
      |--get_images.py
      |--gt_test.txt
      |--gt_train.txt
      |--gt_validation.txt
    |--ocr_data
      |--train
      |--test
      |--validation
      |--get_images.py
      |--gt_test.txt
      |--gt_train.txt
      |--gt_validation.txt
    |--...
  |--kor_dataset
    |--aihub_data
      |--htr
        |--images # 여기에 AIhub 필기체 이미지파일들 넣기
        |--handwriting_data_info1.json # 라벨링 파일
      |--ocr
        |--images # 여기에 AIhub 인쇄체 이미지파일들 넣기
        |--printed_data_info.json # 라벨링 파일
  |--saved_models
    |--TPS-ResNet-BiLSTM-CTC-Seed1111
      |--best_accuracy.pth
      |--best_norm_ED.pth
      |--log_dataset.txt
      |--log_train.txt
      |--opt.txt
  |--aihub_dataset.py
  |--...
  • 우선 이후 학습에서의 overwrite을 방지하기 위해 TPS-ResNet-BiLSTM-CTC-Seed1111 폴더명을 kocrnn_ocr로 수정한다.
cd saved_models
mv TPS-ResNet-BiLSTM-CTC-Seed1111 kocrnn_ocr
  • 필기체(htr)의 추가학습을 위해 kocrnn_ocr 폴더 아래 있는 best_accuracy.pth 파일을 따로 복사해둔다.
cd kocrnn_ocr
cp best_accuracy.pth ../../pretrained_models/kocrnn.pth
cd ../../
  • 복사해둔 kocrnn.pth 모델 위에 필기체(htr)를 추가학습 시켜준다.
# 인쇄체(htr)
CUDA_VISIBLE_DEVICES=0 python3 ./deep-text-recognition-benchmark/train.py \
    --train_data ./deep-text-recognition-benchmark/htr_data_lmdb/train \
    --valid_data ./deep-text-recognition-benchmark/htr_data_lmdb/validation \
    --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC \
    --batch_size 512 --batch_max_length 200 --data_filtering_off --workers 0 \
    --saved_model ./pretrained_models/kocrnn.pth --num_iter 100000 --valInterval 100
  • 필기체의 학습이 완료되면, 인쇄체에서와 같이 저장된 폴더의 이름을 변경하고, 해당 폴더의 best_accuracy.pth를 다시 복사해온다.
cd saved_models
mv TPS-ResNet-BiLSTM-CTC-Seed1111 kocrnn_ocr_htr
cd kocrnn_ocr_htr
cp best_accuracy.pth ../../pretrained_models/kocrnn.pth
cd ../../
  • 위 작업들이 모두 끝나면 아래와 같은 파일 구조가 만들어진다.
최상위폴더
  |--deep-text-recognition-benchmark # 클론해오기
    |--htr_data
      |--train
      |--test
      |--validation
      |--get_images.py
      |--gt_test.txt
      |--gt_train.txt
      |--gt_validation.txt
    |--ocr_data
      |--train
      |--test
      |--validation
      |--get_images.py
      |--gt_test.txt
      |--gt_train.txt
      |--gt_validation.txt
    |--...
  |--kor_dataset
    |--aihub_data
      |--htr
        |--images # 여기에 AIhub 필기체 이미지파일들 넣기
        |--handwriting_data_info1.json # 라벨링 파일
      |--ocr
        |--images # 여기에 AIhub 인쇄체 이미지파일들 넣기
        |--printed_data_info.json # 라벨링 파일
  |--saved_models
    |--kocrnn_ocr
      |--best_accuracy.pth
      ...
    |--kocrnn_ocr_htr
      |--best_accuracy.pth
      ...
  |--pretrained_models
    |--kocrnn.pth
  |--aihub_dataset.py
  |--...


[References]

profile
내가 보려고 정리하는 공부 블로그

1개의 댓글

comment-user-thumbnail
2022년 9월 20일

안녕하세요! 현재 신분증 OCR 구현하기 위해서 해당 글을 참고하고 있는데
AI hub 데이터 셋이 변경된 것 같은데 혹시 변경된 데이터에서 그대로 진행하면 오류가 나는데 해결 방법 아실까요?

답글 달기