tf.data

강민정·2023년 7월 17일

Deep Learning

목록 보기
3/8
post-thumbnail

tf.data 모듈

  • TensorFlow에서 데이터 입력 파이프라인을 구축하기 위한 기능을 제공하는 모듈
  • 모델 학습/평가를 위한 데이터셋을 제공(feeding)하기 위한 모듈
  • 구성요소
    • tf.data.Dataset
      • 데이터셋을 나타내는 클래스
      • 입력 소스의 제공 형태, 어떤 처리를 하는지에 따라 다양한 하위클래스들이 제공
      • 배열, 파일시스템의 파일, CSV 파일, TFRecord 파일 등을 데이터소스로 사용할 수 있음
      • 데이터를 반복(iterate)할 수 있는 iterator를 제공하며, 이 iterator를 통해 데이터를 모델에 공급
    • 데이터 변환 함수
      • map() : Dataset이 제공하는 원소를 처리해서 변환된 원소를 제공
      • filter() : Dataset이 제공하는 원소중 특정 조건을 만족하는(True)인 원소들만 제공
      • batch(size)
      • shuffle(buffer크기) : dataset의 원소들의 순서를 섞음
        • buffer 크기는 섞는 공간의 크기로 데이터보다 크거나 같으면 완전셔플, 적으면 일부만 가져와서 섞어 완전셔플이 안됨
      • **repeat(count)** : 전체 데이터를 한번 다 제공한 뒤 다시 데이터를 제공
    • 데이터 소스
      • tf.data.Dataset.from_tensor_slices() 함수를 사용하여 NumPy 배열이나 텐서를 데이터소스로 사용
      • tf.data.TextLineDataset() 함수를 사용하여 텍스트 파일을 데이터소스로 사용

Dataset 메소드

  • Tensor Type
    • TensorFlow의 기본 data 자료구조
    • tensorflow를 Tensor를 이용해 데이터를 관리
    • tensorflow의 모델이 학습, 평가할 때 사용하는 데이터셋(train dataset, validation dataset, test dataset)은 tf.Tensor 타입이어야 함
  • Tensor -> ndarray로 변환
    a = t.numpy()
    array([1., 2., 3.], dtype=float32)
  • ndarray/list => Tensor
    tf.constant(np.arange(10))
    tf.convert_to_tensor(a)
    <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 2., 3.], dtype=float32)>

Dataset 예제

Dataset 생성

  • 0 ~ 9 정수 => input data
raw_data1 = np.arange(10)
raw_data1

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

→ 대상 raw dataset이 메모리에 있는 ndarray일 경우

⇒ Raw dataset으로 부터 데이터를 읽어 들이는 기능을 제공하는 dataset 생성


dataset = tf.data.Dataset.from_tensor_slices(raw_data1)
print(type(dataset))

<class 'tensorflow.python.data.ops.from_tensor_slices_op._TensorSliceDataset'>

Dataset은 Iterable타입

⇒ 반환 : tf.Tensor 타입

# dataset[0]으로는 조회X
for data in det:
    print(data)

  • dataset에서 5개만 조회하고 싶을 때 (일부 데이터를 확인할 때 사용)
for data in dataset.take(5):
    print(data)

0개의 댓글