오랜만에 작성해보는 연구 기록 일지..!
베이스 코드 골라서 분석하고 수정해야 하는데, hydra라는 새로운 어노테이션?이라고 해야하나 라이브러리? 모듈? 을 접하게 되었다!
마침 이번주에 같이 연구하는 랩장님이 해외 학회로 출장을 가서 ㅎㅎ약간? 여유로울 때 이런거 공부도 해보고 기록도 해보려고 한다
AI 연구를 하고 코드를 구현해본 사람들이라면 알거다.
config가 꼭 필요하다.
def run():
config = yaml.safe_load("file.yaml")
data = Dataset(config.data)
model = Model(config.model)
이런식으로 config를 불러와서 데이터나 모델을 설정해준다.
근데 이 위의 코드는 class의 attribute 방식으로 사용하는 사례이다.
하지만 이 외에도 config 파일을 config를 딕셔너리 형태로 불러와서 사용할 수도 있을 것이다.
이때 이 2가지 방식을 둘다 사용할 수 있게 해주는 것이 hydra이다.
우선 hydra를 사용하려면 가장 먼저 호출되는 함수에 @hydra.main()을 달아주어야 한다.
내가 작성하고 있는 코드에서는 train()메서드일 것이다.
따라서 아래와 같이 작성된다.
@hydra.main(version_base = None, config_path = 'configs')
def train():
...
if __name__ =="__main__":
train()
이렇게 작성이 되면 train()메서드는 hydra에 감염된다고 한다.
그럼 hydra.main의 설계 안에 있는 함수로 train() 메서드가 바꿔치기 된다.
이 바꿔치기된 함수에서 첫 번째 인자가 yaml파일의 configuration이다.
이 configuration 파일은 터미널의 실행문에서 지정된다.
python main.py --config-name train
이 명령어가 실행되면 @hydra.main에서 입력된 config_path가 폴더 이름, --confing-name 인자로 전달받은 train이 yaml파일로 지정되어 최종적으로 configs/train.yaml파일 정보가 config에 반영된다.
내가 사용할 코드는 아래와 같이 구현되어 있었다.
@hydra.main(
version_base=None,
config_path="../config",
config_name="main",
)
def train(cfg_dict: DictConfig):
...
if __name__ == "__main__":
warnings.filterwarnings("ignore")
torch.set_float32_matmul_precision('high')
train()
그리고 ../config/main.yaml 파일이 아래와 같이 구성되어 있었다.
defaults:
- dataset: re10k
- optional dataset/view_sampler_dataset_specific_config: ${dataset/view_sampler}_${dataset}
- model/encoder: costvolume
- model/decoder: splatting_cuda
- loss: [mse]
wandb:
project: mvsplat
entity: placeholder
name: placeholder
mode: disabled
id: null
mode: train
dataset:
overfit_to_scene: null
data_loader:
# Avoid having to spin up new processes to print out visualizations.
train:
num_workers: 10
persistent_workers: true
batch_size: 4
seed: 1234
test:
num_workers: 4
persistent_workers: false
batch_size: 1
seed: 2345
val:
num_workers: 1
persistent_workers: true
batch_size: 1
seed: 3456
optimizer:
lr: 2.e-4
warm_up_steps: 2000
cosine_lr: true
checkpointing:
load: null
every_n_train_steps: 20000 # 5000
save_top_k: -1
pretrained_model: null
resume: true
train:
depth_mode: null
extended_visualization: false
print_log_every_n_steps: 1
test:
output_path: outputs/test
compute_scores: false
eval_time_skip_steps: 0
save_image: true
save_video: false
seed: 111123
trainer:
max_steps: -1
val_check_interval: 0.5
gradient_clip_val: 0.5
num_sanity_val_steps: 2
num_nodes: 1
output_dir: null
그래서 만약 아래와 같은 shell 스크립트를 실행한다면 @hydra.main에 default로 설정해놓은 경로인 ../config/main.yaml 파일을 기본적으로 불러오게 될 것이다.
(이 shell 스크립트에서 추가적으로 config 파일의 경로를 덮어쓰지 않기 때문에)
그리고 train()메서드에 인자로 들어와있는 cfg_dict가 이 main.yaml의 파일 정보를 다 담고 있을 것이다.
python -m src.main +experiment=re10k \
checkpointing.load=checkpoints/re10k.ckpt \
mode=test \
dataset/view_sampler=evaluation \
test.compute_scores=true
config 불러오기 코드가 없어도 된다.
config파일에 dictionary로도 접근할 수 있고, 클래스 attribute로도 접근할 수 있다.
config파일에 특정 attribute가 있는지 확인하기 좋음 (if config.model.get("pretrained"): 이런 느낌으로다가 ~
---😊
여기까진 hydra의 기본 개념!
활용하거나 코드 보면서 새로 알아야 할 것들이 계속 나오는데 이건 GPT와 함께 대화하면서 코드짜보는걸로..~~아자자⭐😊