if __name__ == '__main__':
args = argparse.ArgumentParser(description='PyTorch Template')
args.add_a rgument('-c', '--config', default=None, type=str,
help='config file path (default: None)') #실행하는 설정파일을 불러옴
args.add_argument('-r', '--resume', default=None, type=str,
help='path to latest checkpoint (default: None)') #이전에 실행했던 것을 이어서 실행할 것인지
args.add_argument('-d', '--device', default=None, type=str,
help='indices of GPUs to enable (default: all)') #device는 cpu or gpu 선택
CustomArgs = collections.namedtuple('CustomArgs', 'flags type target')
options = [
CustomArgs(['--lr', '--learning_rate'], type=float, target='optimizer;args;lr'),
CustomArgs(['--bs', '--batch_size'], type=int, target='data_loader;args;batch_size')
]
config = ConfigParser.from_args(args, options)
main(config)
OOP에서 factory pattern은 아래처럼 무언가를 넣어주면 그 아래 객체를 생성해주는 패턴이다.

보통은 "class명.함수" 이런 형태로 적어준 경우 함수에 재료(args)를 넣어주면 그거에 대해 객체를 반환해주는 패턴
따라서, 이 팩토리패턴을 통해 args를 불러온 것을 다 해석을 진행
from_args 코드는 아래와 같고 여기서 중요한 것은 결국 config = read_json(cfg_fname)
@classmethod
def from_args(cls, args, options=''):
"""
Initialize this class from some cli arguments. Used in train, test.
"""
for opt in options:
args.add_argument(*opt.flags, default=None, type=opt.type)
if not isinstance(args, tuple):
args = args.parse_args()
if args.device is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = args.device
if args.resume is not None:
resume = Path(args.resume)
cfg_fname = resume.parent / 'config.json'
else:
msg_no_cfg = "Configuration file need to be specified. Add '-c config.json', for example."
assert args.config is not None, msg_no_cfg
resume = None
cfg_fname = Path(args.config)
config = read_json(cfg_fname)
if args.config and resume:
# update new config for fine-tuning
config.update(read_json(args.config))
# parse custom cli options into dictionary
modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
return cls(config, resume, modification)
데이터를 불러와서 OrderedDict 형태로 데이터 형태로 변환해줌
def read_json(fname):
fname = Path(fname)
with fname.open('rt') as handle:
return json.load(handle, object_hook=OrderedDict)
read_json으로 불러온 orderdict type의 데이터를 config 변수에 할당, 대부분의 설정 정보는 config에 저장되고
cls 객체의 argument로 지정되어 return된다.
4.에 작성된 코드를 보면 from_args는 cls 객체로(@classmethod를 보고 판단할 수 있고, 또한 self 대신 cls가 입력되어 있다. 기존 method 타입(self)과 차이는 해당 링크를 찾아보자
config = read_json(cfg_fname)
if args.config and resume:
# update new config for fine-tuning
config.update(read_json(args.config))
# parse custom cli options into dictionary
modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
return cls(config, resume, modification)
그리고 다시 반환된 객체를 train.py 파일의 config 객체에 할당 후 main()의 변수로 지정하여 main()을 시작하고
config = ConfigParser.from_args(args, options)
main(config)
아래 main()함수가 시작된다.
main은 config file로 실행이 되는데, config 파일은 실질적으로 말하면 dict가 아니라 cls로 리턴했기 때문에 from_args 클래스 자체이다.
def main(config):
logger = config.get_logger('train')
# setup data_loader instances
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()
# build model architecture, then print to console
model = config.init_obj('arch', module_arch)
logger.info(model)
# prepare for (multi-device) GPU training
device, device_ids = prepare_device(config['n_gpu'])
model = model.to(device)
if len(device_ids) > 1:
model = torch.nn.DataParallel(model, device_ids=device_ids)
# get function handles of loss and metrics
criterion = getattr(module_loss, config['loss'])
metrics = [getattr(module_metric, met) for met in config['metrics']]
# build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)
trainer = Trainer(model, criterion, metrics, optimizer,
config=config,
device=device,
data_loader=data_loader,
valid_data_loader=valid_data_loader,
lr_scheduler=lr_scheduler)
trainer.train()
main() 실행 관련해서 언급한 7번에서 말했듯이 config는 실제로는 dict가 아닌 클래스 자체이다.
하지만 재미난 건 parse_config.py 모듈에 있는 class ConfigParser를 보면,
아래와 같이 getitem method를 활용하여 데이터를 ordered dict 형태로 접근할 수 있게 해줬다.
아래 코드를 풀어보면 getitem에 config의 name을 넣어주면 설정값을 확인할 수 있다.
def __getitem__(self, name):
"""Access items like ordinary dict."""
return self.config[name]
main 함수를 실행시키면 첫 줄은 get_logger로, parse_config.py에서 해당 함수를 찾아보면,
우선 verbosity는 log_level인데 0=warning, 1=info, 2=debug 시에 결과를 출력한다.
따라서 verbosity값이 log_levels에 없으면 msg_verbosity가 출력된다.
그리고 이상이 없으면 해당 코드에서 적힌 수준으로 시작되는데,
train이 적히면 train level에서 log데이터를 얻고, test면 test level에서 log data를 얻는다.
def main(config):
logger = config.get_logger('train')
--------------------------------------------------------------------
def get_logger(self, name, verbosity=2):
msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
assert verbosity in self.log_levels, msg_verbosity
logger = logging.getLogger(name)
logger.setLevel(self.log_levels[verbosity])
return logger
그 다음 해당 줄이 실행되는데.. init_obj가 무엇일까
init_obj 함수의 attribute를 보면 name과 module을 받아왔는데, 여기서 agument로 지정한 module_data는 train.py 파일에서 필요한 패키지 및 모듈을 import 시 지정한 alias로 다음과 같이 data_loader.data_loaders 이다.
즉, init_obj는 MNIST 데이터를 불러오기 위한 모듈 자체를 argument로 받는다.
import data_loader.data_loaders as module_data
# setup data_loader instances
data_loader = config.init_obj('data_loader', module_data)
--------------------------------------------------
def init_obj(self, name, module, *args, **kwargs):
"""
Finds a function handle with the name given as 'type' in config, and returns the
instance initialized with corresponding arguments given.
`object = config.init_obj('name', module, a, b=1)`
is equivalent to
`object = module.name(a, b=1)`
"""
module_name = self[name]['type']
module_args = dict(self[name]['args'])
assert all([k not in module_args for k in kwargs]), 'Overwriting kwargs given in config file is not allowed'
module_args.update(kwargs)
return getattr(module, module_name)(*args, **module_args)
여기서 self[name]은 class자체인 config이며 getitem을 통해 name과 맞는 config데이터에서 type을 불러오는데, 여기서 name은 data_loader가 된다.
따라서 해당 config 정보를 보면, 다음과 같이 ministdataloader에 대한 정보임을 알 수 있다.
다시 init_obj를 보면, kwargs가 현재 지정된 config 내에 없다면 에러를 출력하고 만약 있다면,
module_arg에 kwargs를 update한다.
그리고 마지막으로 getattr()을 통해서 해당 모듈, 즉 data_loders 모듈에서 module_name(MnistDataLoader)와 같은 객체의 속성값을 가져온다.
module_args
"data_loader": {
"type": "MnistDataLoader",
"args":{
"data_dir": "data/",
"batch_size": 128,
"shuffle": true,
"validation_split": 0.1,
"num_workers": 2
}
}
여기서 잠깐!
getattr를 왜 쓸까?
사실 어떤 모듈을 부르거나 클래스를 부를 때 직접 코드안에 하드코딩을 해야 하는데 그렇게 하기엔 비효율적일 떄가 있다.
예를 들어 mnist1data 를 불렀는데 이번엔 mnist2data를 부르게 되는 것과 같이 일일이 다 하드코딩 할 순 없는 노릇이다.
따라서, config json file만 바꿔서 사용할 수 있게 해준 것이 getattr 방식이 지원을 해준다.
그리고 data_loader와 동일하게 model, loss, metric, trainable_params, optimizers, lr_scheduler도 똑같은 방법으로 config 정보를 불러온다.
# setup data_loader instances
data_loader = config.init_obj('data_loader', module_data)
valid_data_loader = data_loader.split_validation()
# build model architecture, then print to console
model = config.init_obj('arch', module_arch)
logger.info(model)
# prepare for (multi-device) GPU training
device, device_ids = prepare_device(config['n_gpu'])
model = model.to(device)
if len(device_ids) > 1:
model = torch.nn.DataParallel(model, device_ids=device_ids)
# get function handles of loss and metrics
criterion = getattr(module_loss, config['loss'])
metrics = [getattr(module_metric, met) for met in config['metrics']]
# build optimizer, learning rate scheduler. delete every lines containing lr_scheduler for disabling scheduler
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = config.init_obj('optimizer', torch.optim, trainable_params)
lr_scheduler = config.init_obj('lr_scheduler', torch.optim.lr_scheduler, optimizer)
마지막으로 trainer 모듈의 Trainer 클래스 객체를 생성하는데, 이 Trainer 클래스에서 지금까지 불러온 config 정보를 토대로 다양한 값을 저장하고, Trainer의 method로 지정된 train을 실행시키면 학습이 시작되는 것을 알 수 있다.
trainer = Trainer(model, criterion, metrics, optimizer,
config=config,
device=device,
data_loader=data_loader,
valid_data_loader=valid_data_loader,
lr_scheduler=lr_scheduler)
trainer.train()