python train.py --epochs 50 --batch_size 64 --save_dir ./models
다른 사람들이 작성한 딥러닝 모델 학습 코드를 살펴보면 argparse
가 자주 등장하며 위와 같이 스크립트 형태로 학습을 진행하곤 한다. argparse
모듈은 CLI 환경에서 파이썬 스크립트를 호출할 때 인자값(argument)을 다르게 줌으로써 다른 동작을 하고 싶은 경우에 유용하게 사용할 수 있다. 주로 batch size, learning rate 등의 hyper-parameter나 모델의 저장 경로, 데이터 셋의 경로 등을 지정하기 위해 자주 사용된다.
argparse
기본 준비import argparse
parser = argparse.AugumentParser()
parser.add_augument("--seed", type=int, default=42)
args = parser.parse_arg()
argparse
는 기본적으로 위와 같은 구성으로 이루어져있다.
ArgumentParser()
: argparse
를 사용하기 위해서는 먼저 ArgumenParser()
객체를 생성해야 한다. add_argument()
: add_gument()
를 이용하여 CLI에서 사용할 인자에 대한 정보를 추가한다.parse_arg()
: parse_arg()
를 통해 CLI에서 사용할 인자를 파싱한다.위의 코드에서 add_argument
를 통해 추가한 인자를 parse_arg
를 사용하면 args.seed
와 같은 형태로 인자로 전달한 값을 사용할 수 있다.
add_argument()
다음과 같이 argparse
를 사용하는 코드가 있을 때 주로 사용되는 옵션은 다음과 같다.
import argparse
parser = argparse.ArgumentParser(description='Argparse Tutorial')
parser.add_argument('-n', '--number', type=int, help='an interger for printing repeatably', default=2)
parser.add_argument('-s', '--string', type=str, help='a string for printing repeatably', required=True)
args = parser.parse_args()
for i in range(args.number):
print(f"print number {i} {args.string}")
name or flags
: 옵션 문자열의 이름이나 리스트 (ex. foo
또는 -f
, --foo
)action
: command-line에서 이 인자가 발견될 때 수행할 액션의 기본형store_true
/ store_false
: 해당 옵션을 사용할 경우에 True
/False
를 저장nargs
: 사용하고자 하는 인자의 수를 지정하고자 할 때 사용*
: 0개 이상+
: 1개 이상?
: 0 또는 1개type
: 인자의 데이터 타입 (기본값은 str
)default
: 인자가 입력되지 않았을 때 사용할 기본값required
: 해당 인자가 필수인지의 여부 (기본값은 False
)choices
: 인자로 허용되는 값의 목록help
: -h
나 --help
를 통해 출력될 인자에 대한 설명const
: 일부 action
및 nargs
를 선택할 때 필요한 상숫값📌 인자의 이름에는 -와 _를 쓸 수 있다. 단, python 기본 문법은 변수명에 -를 허용하지 않기 때문에 인자의 이름에 -가 들어갔다면
args
의 인자로 접근하기 위해서는 를 _로 바꿔 주어야한다.
argparse
활용하기모델 학습과 관련하여 아래와 같은 형태로 주로 사용되는 것을 확인할 수 있으므로 아래 코드가 무엇을 의미하는지 분석해보자.
import argparse
if __name__=='__main__':
parser = ArgumentParser()
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--model_path", type=str, default="EleutherAI/polyglot-ko-1.3b")
parser.add_argument("--tokenizer_path", type=str)
parser.add_argument("--dataset_name", type=str, default="heegyu/korquad-chat-v1")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--n_epoch", type=int, default=5)
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--weight_decay", type=float, default=0.1)
parser.add_argument("--warmpup_ratio", type=float, default=0.06)
parser.add_argument("--max_length", type=int, default=512)
parser.add_argument("--save_interval", type=int, default=500)
parser.add_argument("--verbose_interval", type=int, default=20)
parser.add_argument("--save_dir", type=str, default="outputs")
parser.add_argument("--tensorboard_log_interval", type=int, default=20)
parser.add_argument("--tensorboard_path", type=str, default="./tensorboard")
args = parser.parse_args()
set_seed(args.seed)
main(args)