hydra사용법 정리

용권순·2023년 12월 27일
1

Machine_Learning

목록 보기
5/5

hydra 사용법

개요

github를 보면 yaml 파일이 참 많이 보이는데, 사용할 때 마다 헷갈려서 정리를 해본다.

config.yaml

보통 parameter를 설정할 때, main 파일에

EPOCH_COUNT = 20
LR= 5e-5
BATCH_SIZE=128
LOG_PATH = "./runs"

다음과 같이 설정하곤 한다. 이는 변수를 바꿀 때마다 수정하기가 영 번거롭고, parameter가 이곳 저곳에 섞여 있으면 자주 까먹기 때문에 직관적으로, 쉽게 바꿔보자.
위에서 parameter snippet을 yaml으로 저장해보자.

conf라는 폴더에 config.yaml이라는 파일을 생성한다.
그 다음 위에서 정의한 파라메터를

params:
	epoch_count: 20
    lr: 5e-5
    batch_size : 128
#소문자로 작성하는 것이 관습인 것 같다.

yaml을 만들었다면, main에서 불러와야한다.

import hydra를 main.py에서 불러주자.
그리고 다음과 같이 magic function을 정의한다.

'''
main.py
'''
import hydra
#yaml 파일이 있는 위치 : conf, yaml이름 config
@hydra.main(config_path = "conf", config_name="config")
def main(cfg):# hydra를 사용하기 위해서 main function에 config 인자를 넣어줘야함 
 	~
    ~
    return

cfg에는 뭐가 들어올까? 아까 yaml에서 읽은 parameter를 dictionary 형태로 읽어온다.
즉,

{'params':{"epoch_count": 20, "lr": 5e-5, "batch_size":128}}

data dir도 yaml에 설정할 수 있는데, 여기서 yaml에는 string을 지정할 필요가 없다. 즉,

paths: 
	log: ./runs # from LOG_DIR = "./runs" 
    data: ${hydra:runtime.cwd}/../data/raw

하지만 경로나 값에 특수 문자나 YAML 구문과 충돌할 수 있는 문자가 포함되어 있는 경우에는 따옴표를 사용하는 것이 좋다고 한다. 여기서 이상한 부분이 있는데,

  • ${hydra:runtime.cwd}의 의미
    ${hydra:runtime.cwd}는 Hydra 구성 시스템에서 제공하는 특별한 구문. 여기서 hydra:runtime.cwd는 현재 작업 디렉토리(Current Working Directory)의 절대 경로를 참조하는 것을 나타낸다.

이제 우리는 다음과 같은 config.yaml이 존재한다고 가정해보자.

params:
	epoch_count: 20
    lr: 5e-5
    batch_size : 128
Files:
	train_data: ../../../train_data.gz
	train_labels: ../../../train_labels.gz
	valid_data: ../../../valid_data.gz
	valid_labels: ../../../valid_labels.gz

paths: 
	log: ./runs # from LOG_DIR = "./runs" 
    data: ${hydra:runtime.cwd}/../data/raw

config.py

이제 yaml로 parameter를 설정했다고 하자. 그럼 cfg를 그냥 사용하면 되는 걸까?
yaml는 parameter setting을 지정하는 좋은 방법이지만, main에서 직접 접근할 수는 없다. (변수 저장 모음인 것이지, 변수가 아니기에 형변환이 필요한 느낌)
main에서 cfg로 직접 접근하기 위해서 config.py라는 것을 생성해 줘야한다. dataclasses라는 package을 사용해서 직접 접근이 가능하도록 설정해보자.

from dataclasses import dataclass

#config.yaml에 들어있는 설정들을 직접 지정해줘야한다.
@dataclass
class Paths:
	log:str
    data: str

@dataclass 
class Files:
	train_data:str
    train_labels: str
    test_data: str
    test_labels: str

@dataclass 
class Params: 
	epoch_count : int
    lr : float
	batch_size : int

## 위에서 타입 지정을 해줬다면, 
#사용할 config를 dictionary로 다시한번 묶어서 저장한다. 
#예를 들어 MNISTConfig를 만든다고 하면, 
@dataclass 
class MNISTConfig:
	paths: Paths
    files: Files
    params: Params

config.py를 사용해서 타입 지정을 해줬다. 이제 main에서 직접 사용해보자.

main문 사용법

"""main.py"""
from hydra.core.config_store import ConfigStore
cs = ConfigStore.instance() # 접근을 위한 객체 
cs.store(name="mnist_config", node=MNISTConfig)#접근할 config 객체 선정

@hydra.main(config_path = "conf", config_name="config")
def main(cfg:MNISTConfig):# cfg를 MNISTConfig를 사용할 것임을 명시
 	model = nn.Linear()
    #cfg를 사용해서 yaml에 있는 설정을 그대로 가지고 올 수 있게 되었다!
    optimizer = torch.optim.Adam(model.parameters(), lr = cfg.params.lr)
    test_loader = create_dataloader(cfg.params.batch_size, 
    return

좋은 dataloader 스니펫

def create_dataloader(
	root_path : str, data_file:str, label_path: str
):
	data_path  = Path(f"{root_path}/{data_file}")
	label_path = Path(f"{root_path}/{label_path}")

	return

ps 다양한 dataset?

config.yaml에는 공통적인 설정을 적어두고,
dataset에 해당하는 yaml을 따로 설정해두면 편하다.

"""mnist.yaml"""
train_data:str
train_labels: str
test_data: str
test_labels: str
#########
#########
"""config.yaml"""
defaults:
	- files: mnist # mnist.yaml을 import 
    - _self_ # 나머지 설정은 현재 파일(config.yaml)을 사용하겠다. 
params:
	epoch_count: 20
    lr: 5e-5
    batch_size : 128
paths: 
	log: ./runs # from LOG_DIR = "./runs" 
    data: ${hydra:runtime.cwd}/../data/raw

dataset을 바꿀 때 마다 defaults 설정을 바꾸면 된다.

profile
평범한 대학원생입니다...

0개의 댓글