
$ pip install tfx
# 로컬환경에 파이프라인 생성
$ tfx pipeline create --engine=local --pipeline_path={pipeline path}
# 기존에 있는 파이프라인을 수정(업데이트)
$ tfx pipeline update --engine=local --pipeline_path={pipeline path}
# 파이프라인을 검사 (파이프라인을 생성하거나 업데이트 하기전에 실행해보는걸 권장)
$ tfx pipeline compile --engine=local --pipeline_path={pipeline path}
# 기존에 있는 파이프라인을 삭제
$ tfx pipeline delete --engine=local --pipeline_name={삭제하고자 하는 pipeline이름}
# 지정된 오케스트레이터(여기서는 로컬 오케스트레이터)의 모든 파이프라인 나열
$ tfx pipeline list --engine=local
# 생성되어 있는 파이프라인 중 실행하고자 하는 파이프라인 실행
$ tfx run create --engine=local --pipeline_name={실행하고자 하는 pipeline이름}
# 실행중인 파이프라인 중지 (kubeflow만 지원됨)
$ tfx run terminate --run_id={run id} [--endpoint={endpoint} --engine={오케스트레이터} --iap_client_id={IAP 클라이언트 아이디} --namespace={kubeflow의 지정된 namespace}]
# 실행중인 파이프라인 실행 나열 (로컬 & Apache Beam은 지원되지 않음)
$ tfx run list --pipeline_name={pipeline name} [--endpoint={endpoint} --engine={오케스트레이터} --iap_client_id={IAP 클라이언트 아이디} --namespace={kubeflow의 지정된 namespace}]
# 파이프라인 실행의 현재 상태 확인 (로컬 & Apache Beam은 지원되지 않음)
$ tfx run status --pipeline_name={pipeline-name} --run_id={확인하고자 하는 실행의 run id} [--endpoint={endpoint} --engine={오케스트레이터} --iap_client_id={IAP 클라이언트 아이디} --namespace={kubeflow의 지정된 namespace}]
# 파이프라인 실행 삭제 (kubeflow만 지원됨)
$ tfx run delete --run_id={삭제하고자 하는 실행의 run id} [--endpoint={endpoint} --engine={오케스트레이터} --iap_client_id={IAP 클라이언트 아이디} --namespace={kubeflow의 지정된 namespace}]
• 입력: CSV, TFRecord , Avro, Parquet 및 BigQuery와 같은 외부 데이터 소스의 데이터
• 출력: 페이로드 형식에 따라 tf.Example 레코드, tf.SequenceExample 레코드 또는 proto 형식
(본 포스트에서는 CSV 데이터를 다룸)
import os
from tfx.components import CsvExampleGen
from tfx.proto import example_gen_pb2
"""
예시 데이터)
./data
ㄴ data.csv
data.csv
A B label
100 20 1
200 70 0
150 55 1
800 25 0
500 30 0
750 65 1
250 40 0
450 60 1
"""
data_root_path = "./data"
## 학습(train) 데이터와 평가(eval) 데이터가 각각 다른 디렉터리에 있을 때
#input_config = example_gen_pb2.Input(
# splits=[
# example_gen_pb2.Input.Split(name='train', pattern='train/*.csv'),
# example_gen_pb2.Input.Split(name='eval', pattern='eval/*.csv')
# ]
#)
# 입력 데이터를 3:1 비율로 학습(train)과 평가(eval)로 나누기
output_config = example_gen_pb2.Output(
split_config=example_gen_pb2.SplitConfig(splits=[
example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=3),
example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
])
)
example_gen = CsvExampleGen(input_base=data_root_path , output_config=output_config)
(Config 설정방법으로 Span, Date, Version, Range 등 여러 옵션이 있음)
from tfx.components import StatisticsGen
statistics_gen= StatisticsGen(
examples=example_gen.outputs['examples'],
name='statistics-gen'
)
from tfx.components import SchemaGen
schema_gen= SchemaGen(
statistics=statistics_gen.outputs['statistics']
)
from tfx.components import ExampleValidator
validate_stats = ExampleValidator(
statistics=statistics_gen.outputs['statistics'],
schema=schema_gen.outputs['schema']
)
"""
preprocessing_file.py
def preprocessing_fn(inputs):
a_feat = inputs['A']
b_feat = inputs['B']
inputs['feature'] = tf.cast(A / B, tf.float32)
return inputs
"""
from tfx.components import Transform
transform = Transform(
examples=example_gen.outputs['examples'],
schema=schema_gen.outputs['schema'],
module_file={preprocessing_file.py path}
)
"""
preprocessing_file.py
import tensorflow as tf
from tfx.components.trainer.fn_args_utils import FnArgs
def _build_keras_model():
model = tf.keras.Sequential([
tf.keras.layers.Input(shape=(1,), name='feature'), # 입력 이름을 'feature'로 설정
tf.keras.layers.Dense(10, activation='relu', input_shape=(2,)),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
return model
def run_fn(fn_args: FnArgs):
# 명시적 파일 경로 리스트 생성
train_files = glob.glob(fn_args.train_files[0]) # 실제 파일 리스트
eval_files = glob.glob(fn_args.eval_files[0]) # 실제 파일 리스트
# 모델을 빌드
model = _build_keras_model()
# 데이터를 로드
train_dataset = tf.data.TFRecordDataset(train_files, compression_type="GZIP")
eval_dataset = tf.data.TFRecordDataset(eval_files, compression_type="GZIP")
# 데이터셋을 전처리 및 배치단위로 변환
def _parse_fn(serialized_example):
feature_description = {
'feature': tf.io.FixedLenFeature([], tf.float32),
'label': tf.io.FixedLenFeature([], tf.int64),
}
parsed_example = tf.io.parse_single_example(serialized_example, feature_description)
features = {'feature': parsed_example['feature']}
label = parsed_example['label']
return features, label
train_dataset = train_dataset.map(_parse_fn).batch(32)
eval_dataset = eval_dataset.map(_parse_fn).batch(32)
# 모델 학습
model.fit(train_dataset, epochs=10, validation_data=eval_dataset)
# 모델 저장
model.save(fn_args.serving_model_dir, save_format='tf')
"""
from tfx.components import Trainer
trainer = Trainer(
module_file=preprocessing_file_path, # 모델 생성 함수도 동일 파일에 정의됨
examples=transform.outputs['transformed_examples'], # 변환된 예시를 사용
schema=schema_gen.outputs['schema']
)
import os
from tfx.components import Pusher
from tfx.proto import pusher_pb2
pusher = Pusher(
model=trainer.outputs['model'],
push_destination=pusher_pb2.PushDestination(
filesystem=pusher_pb2.PushDestination.Filesystem(
base_directory=os.path.join({pipeline 루트 경로}, 'serving_model')
)
)
)
from tfx.orchestration import pipeline
from tfx.orchestration.local.local_dag_runner import LocalDagRunner
from tfx.orchestration.metadata import sqlite_metadata_connection_config
pipeline_root = os.path.join(os.getcwd(), 'pipeline') # 파이프라인 루트 경로
data_root = os.path.join(pipeline_root, 'data') # 데이터가 저장될 경로
metadata_path = os.path.join(pipeline_root, 'metadata.db') # sqlite 메타데이터 db 경로
# TFX 파이프라인 구성
def _create_pipeline(pipeline_root, data_root, metadata_path):
components = [
example_gen,
statistics_gen, # StatisticsGen 추가
schema_gen, # SchemaGen 추가
transform, # Transform 추가
trainer,
pusher
]
return pipeline.Pipeline(
pipeline_name='tfx_pipeline_name',
pipeline_root=pipeline_root,
components=components,
enable_cache=True,
metadata_connection_config=sqlite_metadata_connection_config(metadata_path), # 메타데이터 DB 설정
beam_pipeline_args=None
)
# LocalDagRunner로 파이프라인 실행
LocalDagRunner().run(_create_pipeline(pipeline_root, data_root, metadata_path))
kubeflow 환경에서 TFX를 통해 파이프라인 관리를 하는 포스트는 추후에..