[python] argparse

whybe-choi·2024년 4월 26일
0

python

목록 보기
1/1
python train.py --epochs 50 --batch_size 64 --save_dir ./models

다른 사람들이 작성한 딥러닝 모델 학습 코드를 살펴보면 argparse가 자주 등장하며 위와 같이 스크립트 형태로 학습을 진행하곤 한다. argparse 모듈은 CLI 환경에서 파이썬 스크립트를 호출할 때 인자값(argument)을 다르게 줌으로써 다른 동작을 하고 싶은 경우에 유용하게 사용할 수 있다. 주로 batch size, learning rate 등의 hyper-parameter나 모델의 저장 경로, 데이터 셋의 경로 등을 지정하기 위해 자주 사용된다.


1. argparse 기본 준비

import argparse

parser = argparse.AugumentParser()
parser.add_augument("--seed", type=int, default=42)
args = parser.parse_arg()

argparse는 기본적으로 위와 같은 구성으로 이루어져있다.

  • ArgumentParser() : argparse를 사용하기 위해서는 먼저 ArgumenParser() 객체를 생성해야 한다.
  • add_argument(): add_gument()를 이용하여 CLI에서 사용할 인자에 대한 정보를 추가한다.
  • parse_arg() : parse_arg()를 통해 CLI에서 사용할 인자를 파싱한다.

위의 코드에서 add_argument를 통해 추가한 인자를 parse_arg를 사용하면 args.seed와 같은 형태로 인자로 전달한 값을 사용할 수 있다.

2. add_argument()

다음과 같이 argparse를 사용하는 코드가 있을 때 주로 사용되는 옵션은 다음과 같다.

import argparse

parser = argparse.ArgumentParser(description='Argparse Tutorial')

parser.add_argument('-n', '--number', type=int, help='an interger for printing repeatably', default=2)
parser.add_argument('-s', '--string', type=str, help='a string for printing repeatably', required=True)

args = parser.parse_args()

for i in range(args.number):
    print(f"print number {i} {args.string}")
  • name or flags: 옵션 문자열의 이름이나 리스트 (ex. foo 또는 -f, --foo)
  • action: command-line에서 이 인자가 발견될 때 수행할 액션의 기본형
    • store_true/ store_false: 해당 옵션을 사용할 경우에 True/False를 저장
  • nargs: 사용하고자 하는 인자의 수를 지정하고자 할 때 사용
    • * : 0개 이상
    • + : 1개 이상
    • ? : 0 또는 1개
  • type: 인자의 데이터 타입 (기본값은 str)
  • default : 인자가 입력되지 않았을 때 사용할 기본값
  • required : 해당 인자가 필수인지의 여부 (기본값은 False)
  • choices: 인자로 허용되는 값의 목록
  • help: -h--help를 통해 출력될 인자에 대한 설명
  • const : 일부 actionnargs를 선택할 때 필요한 상숫값

📌 인자의 이름에는 -와 _를 쓸 수 있다. 단, python 기본 문법은 변수명에 -를 허용하지 않기 때문에 인자의 이름에 -가 들어갔다면 args의 인자로 접근하기 위해서는 를 _로 바꿔 주어야한다.

3. argparse 활용하기

모델 학습과 관련하여 아래와 같은 형태로 주로 사용되는 것을 확인할 수 있으므로 아래 코드가 무엇을 의미하는지 분석해보자.

import argparse

if __name__=='__main__':
    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--model_path", type=str, default="EleutherAI/polyglot-ko-1.3b")
    parser.add_argument("--tokenizer_path", type=str)
    parser.add_argument("--dataset_name", type=str, default="heegyu/korquad-chat-v1")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--n_epoch", type=int, default=5)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--weight_decay", type=float, default=0.1)
    parser.add_argument("--warmpup_ratio", type=float, default=0.06)
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--save_interval", type=int, default=500)
    parser.add_argument("--verbose_interval", type=int, default=20)
    parser.add_argument("--save_dir", type=str, default="outputs")
    parser.add_argument("--tensorboard_log_interval", type=int, default=20)
    parser.add_argument("--tensorboard_path", type=str, default="./tensorboard")

    args = parser.parse_args()
    set_seed(args.seed)

    main(args)

reference

0개의 댓글