tfrecord 압축 옵션 관련

spring·2021년 7월 31일
0

tfrecord를 주로 사용하는 편이다. 아무래도 만들땐 귀찮아도 만들고나면 사용이 편리한 장점도 있고 분할된 파일들을 읽는것 보다 더 빠르게 로드할 수 있으니 tensorflow를 사용할 땐 거의 항상 만들어 사용하는데 사용중 궁금한 점이 생겨서 실험한 내용을 정리해둔다.

이미지를 미리 압축(인코딩)하고 tfrecord로 묶는 방식이 있고 이미지를 압축하지 않고 tfrecord를 만들고 그 후에 전체를 압축하는 방법이 있다.

이미지 압축은 opencv로 처리하니 jpg, png, webp등이 있고 비손실환경에서 작업하려면 webp를 사용하는게 용량이 작아서 좋다.

그런데 문제는 많은 이미지들을 로드할때 매번 decode를 하니 이것도 시간이 만만치 않다. 그래서 decode를 하지 않도록 원본 이미지를 직렬화하고 전체 압축을 하는 방법과 시간을 비교해본다.

압축 포맷별 tfrecord 생성 시간 비교

GZIP: 146s  6.7G
ZLIB: 149s  6.7G
jpg: 144s   733M
png: 258s   5.2G
webp: 3133s 4.2G

webp의 경우 실제로 다른 압축 포맷보다 10배 이상의 encoding 시간을 필요로 한다. 하지만 비손실 압축중 가장 높은 압축률을 보여준다.

압축 포맷별 tfrecord 읽기 시간 비교

GZIP: 93s  
ZLIB: 90s  
jpg: 199s   
png: 235s   
webp: 228s 

결과는 예상한대로 이미지 무손실 압축은 webp가 읽기가 빠르고(용량도 작고) 전체 파일을 무손실로 빠르게 읽으려면 ZLIB이 가장 우세하였다.

아래는 실험에 사용한 코드이다.

perf_write_tfrecord.py

import tensorflow as tf
from time import time
import pandas as pd
from tqdm import tqdm
import cv2
def _bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def to_tfrecord(path_input:str, path_output:str,compress:str=""):
    img_input = cv2.imread(path_input)
    img_output = cv2.imread(path_output)
    height, width, channel = img_input.shape
    if compress:    #.jpg | .png | .webp    
        encoded_output = cv2.imencode("."+compress, img_output)[1].tobytes()
        encoded_input = cv2.imencode("."+compress, img_input)[1].tobytes()
    else:
        encoded_input = img_input.tobytes()
        encoded_output = img_output.tobytes()
    
    feature_dict = {}
    feature_dict['height'] = _int64_feature(value=height)
    feature_dict['width'] = _int64_feature(value=width)
    feature_dict['channel'] = _int64_feature(value=channel)
    feature_dict['image'] = _bytes_feature(encoded_input)
    feature_dict['target'] = _bytes_feature(encoded_output)
    tf_data = tf.train.Example(features=tf.train.Features(feature=feature_dict))
    return tf_data

def save_tfrecord(filename: str,compress1:str="",compress2:str=""):
    train_csv = pd.read_csv('./train.csv')
    train_all_input_files = './train_input_img/' + train_csv['input_img']
    train_all_label_files = './train_label_img/' + train_csv['label_img']
    options = tf.io.TFRecordOptions(compression_type=compress1) # ZLIB GZIP
    tf_writer = tf.io.TFRecordWriter(filename, options=options)
    for i in tqdm(range(len(train_all_input_files))):
        tf_example = to_tfrecord(train_all_input_files[i], train_all_label_files[i],compress2)
        tf_writer.write(tf_example.SerializeToString())
    tf_writer.close()

def main():
    start=time()
    save_tfrecord("traindata(gzip).tfrecord","GZIP","")
    print("GZIP: ",time()-start)
    start=time()
    save_tfrecord("traindata(zlib).tfrecord","ZLIB","")
    print("ZLIB: ",time()-start)
    start=time()
    save_tfrecord("traindata(jpg).tfrecord","","jpg")
    print("jpg: ",time()-start)
    start=time()
    save_tfrecord("traindata(png).tfrecord","","png")
    print("png: ",time()-start)
    start=time()
    save_tfrecord("traindata(webp).tfrecord","","webp")
    print("webp: ",time()-start)

if __name__ == "__main__":
    main()
    

perf_read_tfrecord.py

import tensorflow as tf
from time import time
import pandas as pd
from tqdm import tqdm
import cv2
import numpy as np

def decode_fn(record_bytes):
    return tf.io.parse_single_example(
            record_bytes,
            {
                'height': tf.io.FixedLenFeature([], dtype=tf.int64),  
                'width' : tf.io.FixedLenFeature([], dtype=tf.int64),  
                'channel' : tf.io.FixedLenFeature([], dtype=tf.int64),  
                'image' : tf.io.FixedLenFeature([], dtype=tf.string),  
                'target': tf.io.FixedLenFeature([], dtype=tf.string) 
            }
    )


def read_tfrecord(path_input:str,compression_type:str):
    tfrd = tf.data.TFRecordDataset(path_input,compression_type=compression_type).map(decode_fn)
    images_input=[]
    images_output=[]
    for e in tfrd:
        height = int(e['height'])
        width = int(e['width'])
        channel = int(e['channel'])
        img_input = np.frombuffer(e['image'].numpy(), np.uint8)
        img_output = np.frombuffer(e['target'].numpy(), np.uint8)
        if compression_type:
            img_input=img_input.reshape(height,width,channel)
            img_output=img_output.reshape(height,width,channel)
        else:
            img_input = cv2.imdecode(img_input, cv2.IMREAD_UNCHANGED)
            img_output = cv2.imdecode(img_output, cv2.IMREAD_UNCHANGED)
        images_input.append(img_input)
        images_output.append(img_output)

if __name__ == "__main__":
    start=time()
    read_tfrecord('traindata(gzip).tfrecord',"GZIP")
    print("GZIP: ",time()-start)
    start=time()
    read_tfrecord('traindata(ZLIB).tfrecord',"ZLIB")
    print("ZLIB: ",time()-start)
    start=time()
    read_tfrecord('traindata(jpg).tfrecord',"")
    print("jpg: ",time()-start)
    start=time()
    read_tfrecord('traindata(png).tfrecord',"")
    print("png: ",time()-start)
    start=time()
    read_tfrecord('traindata(webp).tfrecord',"")
    print("webp: ",time()-start)
profile
Researcher & Developer @ NAVER Corp | Designer @ HONGIK Univ.

0개의 댓글