In this post, we'll analyze a popular PyTorch Lightning (PL) template on GitHub.
The original template code can be found in the lightning-hydra-template repository.
├── .github                   <- Github Actions workflows
│
├── configs                   <- Hydra configs
│   ├── callbacks             <- Callbacks configs
│   ├── data                  <- Data configs
│   ├── debug                 <- Debugging configs
│   ├── experiment            <- Experiment configs
│   ├── extras                <- Extra utilities configs
│   ├── hparams_search        <- Hyperparameter search configs
│   ├── hydra                 <- Hydra configs
│   ├── local                 <- Local configs
│   ├── logger                <- Logger configs
│   ├── model                 <- Model configs
│   ├── paths                 <- Project paths configs
│   ├── trainer               <- Trainer configs
│   │
│   ├── eval.yaml             <- Main config for evaluation
│   └── train.yaml            <- Main config for training
│
├── data                      <- Project data
│
├── logs                      <- Logs generated by hydra and lightning loggers
│
├── notebooks                 <- Jupyter notebooks. Naming convention is a number (for ordering),
│                                the creator's initials, and a short `-` delimited description,
│                                e.g. `1.0-jqp-initial-data-exploration.ipynb`.
│
├── scripts                   <- Shell scripts
│
├── src                       <- Source code
│   ├── data                  <- Data scripts
│   ├── models                <- Model scripts
│   ├── utils                 <- Utility scripts
│   │
│   ├── eval.py               <- Run evaluation
│   └── train.py              <- Run training
│
├── tests                     <- Tests of any kind
│
├── .env.example              <- Example of file for storing private environment variables
├── .gitignore                <- List of files ignored by git
├── .pre-commit-config.yaml   <- Configuration of pre-commit hooks for code formatting
├── .project-root             <- File for inferring the position of project root directory
├── environment.yaml          <- File for installing conda environment
├── Makefile                  <- Makefile with commands like `make train` or `make test`
├── pyproject.toml            <- Configuration options for testing and linting
├── requirements.txt          <- File for installing python dependencies
├── setup.py                  <- File for installing project as a package
└── README.md
In short, the project consists of the Hydra configs in the configs folder and the code in the src folder.
Let's look at how train.yaml and train.py, that is, Hydra and src, interact with each other.
train.yaml
# @package _global_

# specify here default configuration
# order of defaults determines the order in which configs override each other
defaults:
  - _self_
  - data: mnist.yaml
  - model: mnist.yaml
  - callbacks: default.yaml
  - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
  - trainer: default.yaml
  - paths: default.yaml
  - extras: default.yaml
  - hydra: default.yaml

  # experiment configs allow for version control of specific hyperparameters
  # e.g. best hyperparameters for given model and datamodule
  - experiment: null

  # config for hyperparameter optimization
  - hparams_search: null

  # optional local config for machine/user specific settings
  # it's optional since it doesn't need to exist and is excluded from version control
  - optional local: default.yaml

  # debugging config (enable through command line, e.g. `python train.py debug=default`)
  - debug: null

# task name, determines output directory path
task_name: "train"

# tags to help you identify your experiments
# you can overwrite this in experiment configs
# overwrite from command line with `python train.py tags="[first_tag, second_tag]"`
tags: ["dev"]

# set False to skip model training
train: True

# evaluate on test set, using best model weights achieved during training
# lightning chooses best weights based on the metric specified in checkpoint callback
test: True

# compile model for faster training with pytorch 2.0
compile: False

# simply provide checkpoint path to resume training
ckpt_path: null

# seed for random number generators in pytorch, numpy and python.random
seed: null
train.py
@utils.task_wrapper
def train(cfg: DictConfig) -> Tuple[dict, dict]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
    training.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    Args:
        cfg (DictConfig): Configuration composed by Hydra.

    Returns:
        Tuple[dict, dict]: Dict with metrics and dict with all instantiated objects.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

    if cfg.get("compile"):
        log.info("Compiling model!")
        model = torch.compile(model)

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
First, note that the train function is wrapped in a decorator. @utils.task_wrapper produces logs when train fails, which is useful for multiruns and for saving information about a crash.
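To make the decorator's role concrete, here is a minimal sketch of what such a wrapper can look like (simplified for illustration; the template's actual utils.task_wrapper additionally closes loggers and logs the output directory):

import functools
import logging
from typing import Callable

log = logging.getLogger(__name__)

def task_wrapper(task_func: Callable) -> Callable:
    """Minimal sketch: run the task and, if it raises, log the
    exception before re-raising so failed multiruns leave a trace."""
    @functools.wraps(task_func)
    def wrap(cfg):
        try:
            return task_func(cfg)
        except Exception:
            log.exception("Task failed!")
            raise
    return wrap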
The function body then relies on hydra.utils.instantiate, a Hydra utility that builds objects directly from the config:
>>> hydra.utils.instantiate(cfg.model)

# train.yaml
defaults:
  - model: mnist.yaml

# model/mnist.yaml
_target_: src.models.mnist_module.MNISTLitModule
optimizer:
  _target_: torch.optim.Adam
  ...
As shown above, Hydra resolves train.yaml -> model/mnist.yaml and creates the instance from the model config defined there.
After the instances are created this way, training and/or testing proceeds.
Both are driven by the trainer, following the behavior defined in the model obtained through cfg.
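Conceptually, hydra.utils.instantiate imports the class named in _target_ and calls it with the remaining keys as keyword arguments. Here is a minimal sketch of the idea (the real function also handles _partial_, positional arguments, recursion flags, and more):

import importlib

def instantiate_sketch(cfg: dict):
    """Minimal sketch of hydra.utils.instantiate, using a plain dict
    in place of a DictConfig."""
    kwargs = dict(cfg)
    # split "src.models.mnist_module.MNISTLitModule" into module path and class name
    module_path, _, class_name = kwargs.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    # recurse into nested configs that declare their own _target_ (e.g. optimizer)
    kwargs = {
        k: instantiate_sketch(v) if isinstance(v, dict) and "_target_" in v else v
        for k, v in kwargs.items()
    }
    return cls(**kwargs)

Returning to train.py, the entry point main ties all of this together: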
@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    utils.extras(cfg)

    # train the model
    metric_dict, _ = train(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    metric_value = utils.get_metric_value(
        metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
    )

    # return optimized metric
    return metric_value


if __name__ == "__main__":
    main()
main uses the @hydra.main decorator, which sets up Hydra and composes the config from configs/train.yaml.
utils.extras then applies options (ignore_warnings, etc.), and train is executed.
Finally, utils.get_metric_value unpacks the metric dict and returns the requested value.
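That last step exists so Hydra-based hyperparameter optimization has a single scalar to optimize. A minimal sketch of what utils.get_metric_value does, assuming the metric dict holds tensors from trainer.callback_metrics (the template's version adds more detailed error messages):

from typing import Optional

def get_metric_value(metric_dict: dict, metric_name: Optional[str]) -> Optional[float]:
    """Minimal sketch: safely pull one metric out of the merged metric dict."""
    # optimized_metric is unset in plain training runs, so there is nothing to return
    if not metric_name:
        return None
    if metric_name not in metric_dict:
        raise ValueError(f"Metric '{metric_name}' not found in metric_dict!")
    # trainer.callback_metrics holds tensors; convert to a plain float
    return metric_dict[metric_name].item()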
To summarize, using this template boils down to the following steps (an example experiment config is sketched right after the list):
1. Add the model you want under src/models/.
2. Define a datamodule for your dataset under src/data/.
3. Modify the experiment or train config files.
4. Run `python src/train.py experiment=experiment_name.yaml`, or simply `python src/train.py`.
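As a concrete illustration of step 3, here is a hypothetical experiment config written in the template's style (the file name example.yaml and all hyperparameter values below are made up for this sketch):

# configs/experiment/example.yaml
# @package _global_

# to execute this experiment run:
# python src/train.py experiment=example

defaults:
  - override /data: mnist.yaml
  - override /model: mnist.yaml
  - override /trainer: default.yaml

# everything below is merged on top of the defaults composed by train.yaml
tags: ["mnist", "example"]
seed: 12345

trainer:
  max_epochs: 10

model:
  optimizer:
    lr: 0.002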
Beyond this, the template provides many other features, such as hparams_search.
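For example, according to the template's README, an Optuna-based hyperparameter sweep can be launched with `python train.py hparams_search=mnist_optuna experiment=example`.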