[BoostCamp AI Tech / Day 11, PytTorch] 3강

newbie·2021년 8월 17일
0

1. 프로젝트 폴더는 train.py로 시작

2. if name == 'main': if 절 확인

  • -c : config, 실행하는 설정파일
    • config.json 파일을 열어보면 해당 기능이 어떤 역할을 수행하는지 알 수 있음
    • 다양한 하이퍼파라미터 값들을 지정해놓음
  • -r : 이전에 이력을 이어서 실행할 것인지
  • -d : device는 cpu, gpu중 어떤 것을 선택할지
if __name__ == '__main__':
    args = argparse.ArgumentParser(description='PyTorch Template')
    args.add_a rgument('-c', '--config', default=None, type=str,
                      help='config file path (default: None)') #실행하는 설정파일을 불러옴
    args.add_argument('-r', '--resume', default=None, type=str,
                      help='path to latest checkpoint (default: None)') #이전에 실행했던 것을 이어서 실행할 것인지
    args.add_argument('-d', '--device', default=None, type=str,
                      help='indices of GPUs to enable (default: all)') #device는 cpu or gpu 선택

3. 2번에서 각각 불러온 것을 namedtuple을 만든 다음 configParser로 넘김

    CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
    options = [
        CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
        CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size')
    ]
    config = ConfigParser.from_args(args, options)
    main(config)

4. ConfigParser.from_args의 factory pattern으로 OOP 생성

OOP에서 factory pattern은 아래처럼 무언가를 넣어주면 그 아래 객체를 생성해주는 패턴이다.

보통은 "class명.함수" 이런 형태로 적어준 경우 함수에 재료(args)를 넣어주면 그거에 대해 객체를 반환해주는 패턴
따라서, 이 팩토리패턴을 통해 args를 불러온 것을 다 해석을 진행
from_args 코드는 아래와 같고 여기서 중요한 것은 결국 config = read_json(cfg_fname)

    @classmethod
    def from_args(cls, args, options=''):
        """
        Initialize this class from some cli arguments. Used in train, test.
        """
        for opt in options:
            args.add_argument(*opt.flags, default=None, type=opt.type)
        if not isinstance(args, tuple):
            args = args.parse_args()

        if args.device is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = args.device
        if args.resume is not None:
            resume = Path(args.resume)
            cfg_fname = resume.parent / 'config.json'
        else:
            msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
            assert args.config is not None, msg_no_cfg
            resume = None
            cfg_fname = Path(args.config)
        
        config = read_json(cfg_fname)
        if args.config and resume:
            # update new config for fine-tuning
            config.update(read_json(args.config))

        # parse custom cli options into dictionary
        modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
        return cls(config, resume, modification)

5. utils-util.py의 read_json 함수 확인

데이터를 불러와서 OrderedDict 형태로 데이터 형태로 변환해줌

def read_json(fname):
    fname = Path(fname)
    with fname.open('rt') as handle:
        return json.load(handle, object_hook=OrderedDict)

6. 객체 생성

read_json으로 불러온 orderdict type의 데이터를 config 변수에 할당, 대부분의 설정 정보는 config에 저장되고
cls 객체의 argument로 지정되어 return된다.
4.에 작성된 코드를 보면 from_args는 cls 객체로(@classmethod를 보고 판단할 수 있고, 또한 self 대신 cls가 입력되어 있다. 기존 method 타입(self)과 차이는 해당 링크를 찾아보자

config = read_json(cfg_fname)
        if args.config and resume:
            # update new config for fine-tuning
            config.update(read_json(args.config))

        # parse custom cli options into dictionary
        modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
        return cls(config, resume, modification)

7. main() 실행

그리고 다시 반환된 객체를 train.py 파일의 config 객체에 할당 후 main()의 변수로 지정하여 main()을 시작하고

config = ConfigParser.from_args(args, options)
    main(config)

아래 main()함수가 시작된다.
main은 config file로 실행이 되는데, config 파일은 실질적으로 말하면 dict가 아니라 cls로 리턴했기 때문에 from_args 클래스 자체이다.

def main(config):
    logger = config.get_logger('train')

    # setup data_loader instances
    data_loader = config.init_obj('data_loader', module_data)
    valid_data_loader = data_loader.split_validation()

    # build model architecture, then print to console
    model = config.init_obj('arch', module_arch)
    logger.info(model)

    # prepare for (multi-device) GPU training
    device, device_ids = prepare_device(config['n_gpu'])
    model = model.to(device)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    # get function handles of loss and metrics
    criterion = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
    lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

    trainer = Trainer(model, criterion, metrics, optimizer,
                      config=config,
                      device=device,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler)

    trainer.train()

8. cls를 dict 형태로 데이터를 불러올 수 있도록 한 getitem

main() 실행 관련해서 언급한 7번에서 말했듯이 config는 실제로는 dict가 아닌 클래스 자체이다.
하지만 재미난 건 parse_config.py 모듈에 있는 class ConfigParser를 보면,
아래와 같이 getitem method를 활용하여 데이터를 ordered dict 형태로 접근할 수 있게 해줬다.
아래 코드를 풀어보면 getitem에 config의 name을 넣어주면 설정값을 확인할 수 있다.

    def __getitem__(self, name):
        """Access items like ordinary dict."""
        return self.config[name]

9. 다시 main 함수 코드 풀이

main 함수를 실행시키면 첫 줄은 get_logger로, parse_config.py에서 해당 함수를 찾아보면,
우선 verbosity는 log_level인데 0=warning, 1=info, 2=debug 시에 결과를 출력한다.
따라서 verbosity값이 log_levels에 없으면 msg_verbosity가 출력된다.
그리고 이상이 없으면 해당 코드에서 적힌 수준으로 시작되는데,
train이 적히면 train level에서 log데이터를 얻고, test면 test level에서 log data를 얻는다.

def main(config):
    logger = config.get_logger('train')
--------------------------------------------------------------------
def get_logger(self, name, verbosity=2):
    msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
    assert verbosity in self.log_levels, msg_verbosity
    logger = logging.getLogger(name)
    logger.setLevel(self.log_levels[verbosity])
    return logger

그 다음 해당 줄이 실행되는데.. init_obj가 무엇일까
init_obj 함수의 attribute를 보면 name과 module을 받아왔는데, 여기서 agument로 지정한 module_data는 train.py 파일에서 필요한 패키지 및 모듈을 import 시 지정한 alias로 다음과 같이 data_loader.data_loaders 이다.
즉, init_obj는 MNIST 데이터를 불러오기 위한 모듈 자체를 argument로 받는다.

import data_loader.data_loaders as module_data

# setup data_loader instances
    data_loader = config.init_obj('data_loader', module_data)
--------------------------------------------------
def init_obj(self, name, module, *args, **kwargs):
    """
    Finds a function handle with the name given as 'type' in config, and returns the
    instance initialized with corresponding arguments given.

    `object = config.init_obj('name', module, a, b=1)`
    is equivalent to
    `object = module.name(a, b=1)`
    """
    module_name = self[name]['type']
    module_args = dict(self[name]['args'])
    assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
    module_args.update(kwargs)
    return getattr(module, module_name)(*args, **module_args)

여기서 self[name]은 class자체인 config이며 getitem을 통해 name과 맞는 config데이터에서 type을 불러오는데, 여기서 name은 data_loader가 된다.
따라서 해당 config 정보를 보면, 다음과 같이 ministdataloader에 대한 정보임을 알 수 있다.
다시 init_obj를 보면, kwargs가 현재 지정된 config 내에 없다면 에러를 출력하고 만약 있다면,
module_arg에 kwargs를 update한다.
그리고 마지막으로 getattr()을 통해서 해당 모듈, 즉 data_loders 모듈에서 module_name(MnistDataLoader)와 같은 객체의 속성값을 가져온다.

module_args

"data_loader": {
    "type": "MnistDataLoader",
    "args":{
        "data_dir": "data/",
        "batch_size": 128,
        "shuffle": true,
        "validation_split": 0.1,
        "num_workers": 2
    }
}

여기서 잠깐!
getattr를 왜 쓸까?
사실 어떤 모듈을 부르거나 클래스를 부를 때 직접 코드안에 하드코딩을 해야 하는데 그렇게 하기엔 비효율적일 떄가 있다.
예를 들어 mnist1data 를 불렀는데 이번엔 mnist2data를 부르게 되는 것과 같이 일일이 다 하드코딩 할 순 없는 노릇이다.
따라서, config json file만 바꿔서 사용할 수 있게 해준 것이 getattr 방식이 지원을 해준다.

그리고 data_loader와 동일하게 model, loss, metric, trainable_params, optimizers, lr_scheduler도 똑같은 방법으로 config 정보를 불러온다.

# setup data_loader instances
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()

# build model architecture, then print to console
model = config.init_obj('arch', module_arch)
logger.info(model)

# prepare for (multi-device) GPU training
device, device_ids = prepare_device(config['n_gpu'])
model = model.to(device)
if len(device_ids) > 1:
    model = torch.nn.DataParallel(model, device_ids=device_ids)

# get function handles of loss and metrics
criterion = getattr(module_loss, config['loss'])
metrics = [getattr(module_metric, met) for met in config['metrics']]
    
# build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)

train!!

마지막으로 trainer 모듈의 Trainer 클래스 객체를 생성하는데, 이 Trainer 클래스에서 지금까지 불러온 config 정보를 토대로 다양한 값을 저장하고, Trainer의 method로 지정된 train을 실행시키면 학습이 시작되는 것을 알 수 있다.

trainer = Trainer(model, criterion, metrics, optimizer,
                      config=config,
                      device=device,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler)

    trainer.train()
profile
DL, NLP Engineer to be....

0개의 댓글