[python] CRNN 한글 모델 학습하기(3) - model 성능 테스트

apphia·2021년 12월 1일
post-thumbnail
  1. 최상위 폴더에 deep-text-recognition-benchmark 클론해오기
  2. 최상위 폴더에 kor_dataset 폴더 생성 후, 데이터셋들을 해당 폴더로 옮기기
  3. AIhub 데이터셋 학습하기
  4. Finetuning 데이터셋 추가 학습하기
  5. 완성된 pretrained model 테스트하기

이번 포스팅에서는 지금까지 학습한 모델을 가지고 테스트하는 과정에 대해 정리한다.

현재 파일구조는 아래와 같다.

최상위폴더
  |--deep-text-recognition-benchmark
    |--data 폴더들...
    |--lmdb 폴더들...
    |--demo.py
    |--train.py
    |--test.py
    |--기타 파일들...
  |--kor_dataset
    |--aihub_data
      |--htr
      |--ocr
   |--finetuning_data
      |--made1
      |--made2
  |--saved_models
    |--kocrnn_ocr
    |--kocrnn_ocr_htr
    |--kocrnn_ocr_htr_made1
    |--kocrnn_ocr_htr_made12
  |--pretrained_models
    |--kocrnn.pth
  |--aihub_dataset.py
  |--finetuning_dataset.py
  |--기타 파일들...

demo.py 실행

  • 위 파일들 중 pretrained model을 테스트하는데 관여하는 파일은 demo.py이다. 일단은 그냥 실행시켜보도록 한다.
CUDA_VISIBLE_DEVICE=0 python3 ./deep-text-recognition-benchmark/demo.py \
    --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC \
    --image_folder ./deep-text-recognition-benchmark/htr_data/test/ \
    --saved_model ./pretrained_models/kocrnn.pth --workers 0
  • demo.py를 그냥 돌려보면 콘솔 창에는 아래와 같은 구조로 출력된다.
{img path/name} | {predicted characters} | {confidence score}

demo.py 수정

  • 좀 더 명확한 성능지표를 만들기 위해 여기에 CER을 계산할 수 있는 함수를 추가시켜보도록 한다.

  • CER에 대한 자세한 내용은 여기를 참고하였다.

CER(Character Error Rate)

  • 예측한 텍스트가 s1, 원래 텍스트가 s2라고 했을 때, s1이 s2가 되는 과정에서 발생하는 첨자, 오자, 탈자 정보가 필요하다.
  • Normalization 공식: (s1을 s2로 변환하는 과정에서 발생하는 첨자+오자+탈자)/(s2의 길이)
  • levenshtein distance를 이용해 첨자+오자+탈자를 계산할 수 있다.
  • 위 정보들을 바탕으로 demo.py를 수정한다.

1. levenshtein distance를 계산해주는 함수를 추가해준다.

def levenshtein(s1, s2, debug=False):
    if len(s1) < len(s2):
        return levenshtein(s2, s1, debug)

    if len(s2) == 0:
        return len(s1)

    previous_row = range(len(s2) + 1)    
    for i, c1 in enumerate(s1):
        current_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = previous_row[j + 1] + 1
            deletions = current_row[j] + 1
            substitutions = previous_row[j] + (c1 != c2)
            current_row.append(min(insertions, deletions, substitutions))

        previous_row = current_row

    return previous_row[-1]

2. demo함수 수정하기

  • demo 함수에서 log파일을 작성하는 부분을 찾아 적절하게 수정해준다.
  • 특히 for문 안에서 테스트셋들에 대한 cer을 계산할 수 있도록 levenshtein 함수를 호출해준다.
# def demo(opt)

# 추가) dashed_line 숫자 변경: 80->120
# 추가) head에 character_error_rate 추가
log = open(opt.log_filename, 'a')
dashed_line = '-' * 120
head = f'{"image_path":25s}\t{"ground_truth":15s}\t{"predicted_labels":15s}\tconfidence score\tcharacter error rate'
            
print(f'{dashed_line}\n{head}\n{dashed_line}')
log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

preds_prob = F.softmax(preds, dim=2)
preds_max_prob, _ = preds_prob.max(dim=2)
# 추가) total_err, total_len
total_err = 0
total_len = 0

for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
    if 'Attn' in opt.Prediction:
        pred_EOS = pred.find('[s]')
        pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
        pred_max_prob = pred_max_prob[:pred_EOS]

        # calculate confidence score (= multiply of pred_max_prob)
        confidence_score = pred_max_prob.cumprod(dim=0)[-1]
    
    # 추가) 1. extract labels from test images
    with open(opt.label_test, 'r') as f:
        lines = f.readlines()
        for line in lines:
            info = line.split('.png\t')
            file_name = info[0] + '.png'
            label = info[1].strip() 
            if file_name.split('/')[-1] == img_name.split('/')[-1]: break
            
    # 추가) 2. calculate CER
    # CER: Ground Truth(img_name)를 OCR 출력(pred)로 변환하는데 필요한 최소 문자 수준 작업 수
    # CER = 100 * [1 - (탈자개수 + 오자개수 + 첨자개수) / 원본글자수]
    error_num = levenshtein(pred, label, debug=True) # 오자 + 탈자 + 첨자
    cer = error_num / len(label)                     # 현재 텍스트에 대한 cer
    total_len += len(label)                          # 전체 텍스트(ground truth)의 길이
    total_err += error_num                              # 전체 텍스트의 오자+탈자+첨자 수
    total_cer = total_err / total_len                   # 전체 텍스트에 대한 cer

    print(f'{img_name:25s}\t{label:15s}\t{pred:15s}\t{confidence_score:0.4f}\t{cer:0.4f}')
    log.write(f'{img_name:25s}\t{label:15s}\t{pred:15s}\t{confidence_score:0.4f}\t{cer:0.4f}\n')

print(f'total CER = {total_cer}')
log.write(f'total_CER = {total_cer}\n')
log.close()

3. parser 수정하기

  • train.py와 test.py와 마찬가지로 character 옵션을 수정해주어야 한다.
  • 테스트 결과로 저장할 log 파일명을 argument로 추가한다 : --log_filename
  • 테스트셋의 라벨링 파일을 argument로 추가한다 : --label_test
parser.add_argument('--character', type=str,
            default='0123456789abcdefghijklmnopqrstuvwxyz가각간갇갈감갑값갓강갖같갚갛개객걀걔거걱건걷걸검겁것겉게겨격겪견결겹경곁계고곡곤곧골곰곱곳공과관광괜괴굉교구국군굳굴굵굶굽궁권귀귓규균귤그극근글긁금급긋긍기긴길김깅깊까깍깎깐깔깜깝깡깥깨꺼꺾껌껍껏껑께껴꼬꼭꼴꼼꼽꽂꽃꽉꽤꾸꾼꿀꿈뀌끄끈끊끌끓끔끗끝끼낌나낙낚난날낡남납낫낭낮낯낱낳내냄냇냉냐냥너넉넌널넓넘넣네넥넷녀녁년념녕노녹논놀놈농높놓놔뇌뇨누눈눕뉘뉴늄느늑는늘늙능늦늬니닐님다닥닦단닫달닭닮담답닷당닿대댁댐댓더덕던덜덟덤덥덧덩덮데델도독돈돌돕돗동돼되된두둑둘둠둡둥뒤뒷드득든듣들듬듭듯등디딩딪따딱딴딸땀땅때땜떠떡떤떨떻떼또똑뚜뚫뚱뛰뜨뜩뜯뜰뜻띄라락란람랍랑랗래랜램랫략량러럭런럴럼럽럿렁렇레렉렌려력련렬렵령례로록론롬롭롯료루룩룹룻뤄류륙률륭르른름릇릎리릭린림립릿링마막만많말맑맘맙맛망맞맡맣매맥맨맵맺머먹먼멀멈멋멍멎메멘멩며면멸명몇모목몬몰몸몹못몽묘무묵묶문묻물뭄뭇뭐뭘뭣므미민믿밀밉밌및밑바박밖반받발밝밟밤밥방밭배백뱀뱃뱉버번벌범법벗베벤벨벼벽변별볍병볕보복볶본볼봄봇봉뵈뵙부북분불붉붐붓붕붙뷰브븐블비빌빔빗빚빛빠빡빨빵빼뺏뺨뻐뻔뻗뼈뼉뽑뿌뿐쁘쁨사삭산살삶삼삿상새색샌생샤서석섞선설섬섭섯성세섹센셈셋셔션소속손솔솜솟송솥쇄쇠쇼수숙순숟술숨숫숭숲쉬쉰쉽슈스슨슬슴습슷승시식신싣실싫심십싯싱싶싸싹싼쌀쌍쌓써썩썰썹쎄쏘쏟쑤쓰쓴쓸씀씌씨씩씬씹씻아악안앉않알앓암압앗앙앞애액앨야약얀얄얇양얕얗얘어억언얹얻얼엄업없엇엉엊엌엎에엔엘여역연열엷염엽엿영옆예옛오옥온올옮옳옷옹와완왕왜왠외왼요욕용우욱운울움웃웅워원월웨웬위윗유육율으윽은을음응의이익인일읽잃임입잇있잊잎자작잔잖잘잠잡잣장잦재쟁쟤저적전절젊점접젓정젖제젠젯져조족존졸좀좁종좋좌죄주죽준줄줌줍중쥐즈즉즌즐즘증지직진질짐집짓징짙짚짜짝짧째쨌쩌쩍쩐쩔쩜쪽쫓쭈쭉찌찍찢차착찬찮찰참찻창찾채책챔챙처척천철첩첫청체쳐초촉촌촛총촬최추축춘출춤춥춧충취츠측츰층치칙친칠침칫칭카칸칼캄캐캠커컨컬컴컵컷케켓켜코콘콜콤콩쾌쿄쿠퀴크큰클큼키킬타탁탄탈탑탓탕태택터턱턴털텅테텍텔템토톤톨톱통퇴투툴툼퉁튀튜트특튼튿틀틈티틱팀팅파팎판팔팝패팩팬퍼퍽페펜펴편펼평폐포폭폰표푸푹풀품풍퓨프플픔피픽필핏핑하학한할함합항해핵핸햄햇행향허헌험헤헬혀현혈협형혜호혹혼홀홈홉홍화확환활황회획횟횡효후훈훌훨휘휴흉흐흑흔흘흙흡흥흩희흰히힘?!.,()',
help='character label')
parser.add_argument('--label_test', required=True, help='path to labels.txt for test images')
parser.add_argument('--log_filename', required=True, help='log file name')

테스트 & 성능지표 보기

1. 기존의 gt_test.txt와 test 데이터셋을 이용할 경우

# 최상위 폴더로 이동
CUDA_VISIBLE_DEVICE=0,1,2 python3 ./deep-text-recognition-benchmark/demo.py \
    --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC \
    --image_folder ./deep-text-recognition-benchmark/htr_data/test/ \
    --label_test ./deep-text-recognition-benchmark/htr_data/gt_test.txt \
    --log_filename ./htr_test_log.txt \
    --saved_model ./pretrained_models/kocrnn.pth

2. 별도의 테스트 데이터셋 추가한 경우

  • test/images 폴더에 테스트 이미지들을 넣고, test/labels.txt로 라벨링 파일을 위치시킨다.
CUDA_VISIBLE_DEVICE=0,1,2 python3 ./deep-text-recognition-benchmark/demo.py \
    --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction CTC \
    --image_folder ./test/images/ \
    --label_test ./test/labels.txt \
    --log_filename ./test/test_log.txt \
    --saved_model ./pretrained_models/kocrnn.pth

2번을 수행했을 때 log 파일 결과 이미지

  • 오른쪽부터 | 파일이름 | 실제 텍스트(라벨) | 예측한 텍스트 | confidence score | character error rate | 순으로 출력된다.
    test_log.txt

완성된 소스 코드는 여기에 저장해두었다.


[References]

profile
내가 보려고 정리하는 공부 블로그

1개의 댓글

comment-user-thumbnail
2021년 12월 9일

imgW imgH가 학습할때 기본 32 / 100으로 고정되어있는데 그 부분은 테스트할 때 이상이 없었는지 궁금합니다. 학습에서는 좋은 acc가 나오는데 demo만하면 제대로 인식을 못하네요 ㅠ

답글 달기