$ torchrun --nproc_per_node=1 tools/train.py -c configs/rtdetr/rtdetr_r50vd_6x_coco.yml -r output/rtdetr_r50vd_6x_coco_from_paddle.pth --test-only
테스트는 train.py에서 rtdetr_r50vd_6x_coco.yml로 접근한다.
__include__: [
'../dataset/coco_detection.yml',
'../runtime.yml',
'./include/dataloader.yml',
'./include/optimizer.yml',
'./include/rtdetr_r50vd.yml',
]
output_dir: ./output/rtdetr_r50vd_6x_coco
하나씩 무슨 일을 살펴보는 것은 나중에 하고, 일단 test에 대해서만 알아보자.
if args.test_only:
solver.val()
train.py를 보면, solver.py를 통해 test를 한다는 것을 알 수 있다. solver.py를 보자.
"""by lyuwenyu
"""
import torch
import torch.nn as nn
from datetime import datetime
from pathlib import Path
from typing import Dict
from src.misc import dist
from src.core import BaseConfig
class BaseSolver(object):
def __init__(self, cfg: BaseConfig) -> None:
self.cfg = cfg
def setup(self, ):
'''Avoid instantiating unnecessary classes
'''
cfg = self.cfg
device = cfg.device
self.device = device
self.last_epoch = cfg.last_epoch
self.model = dist.warp_model(cfg.model.to(device), cfg.find_unused_parameters, cfg.sync_bn)
self.criterion = cfg.criterion.to(device)
self.postprocessor = cfg.postprocessor
# NOTE (lvwenyu): should load_tuning_state before ema instance building
if self.cfg.tuning:
print(f'Tuning checkpoint from {self.cfg.tuning}')
self.load_tuning_state(self.cfg.tuning)
self.scaler = cfg.scaler
self.ema = cfg.ema.to(device) if cfg.ema is not None else None
self.output_dir = Path(cfg.output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
def train(self, ):
self.setup()
self.optimizer = self.cfg.optimizer
self.lr_scheduler = self.cfg.lr_scheduler
# NOTE instantiating order
if self.cfg.resume:
print(f'Resume checkpoint from {self.cfg.resume}')
self.resume(self.cfg.resume)
self.train_dataloader = dist.warp_loader(self.cfg.train_dataloader, \
shuffle=self.cfg.train_dataloader.shuffle)
self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, \
shuffle=self.cfg.val_dataloader.shuffle)
def eval(self, ):
self.setup()
self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, \
shuffle=self.cfg.val_dataloader.shuffle)
if self.cfg.resume:
print(f'resume from {self.cfg.resume}')
self.resume(self.cfg.resume)
def state_dict(self, last_epoch):
'''state dict
'''
state = {}
state['model'] = dist.de_parallel(self.model).state_dict()
state['date'] = datetime.now().isoformat()
# TODO
state['last_epoch'] = last_epoch
if self.optimizer is not None:
state['optimizer'] = self.optimizer.state_dict()
if self.lr_scheduler is not None:
state['lr_scheduler'] = self.lr_scheduler.state_dict()
# state['last_epoch'] = self.lr_scheduler.last_epoch
if self.ema is not None:
state['ema'] = self.ema.state_dict()
if self.scaler is not None:
state['scaler'] = self.scaler.state_dict()
return state
def load_state_dict(self, state):
'''load state dict
'''
# TODO
if getattr(self, 'last_epoch', None) and 'last_epoch' in state:
self.last_epoch = state['last_epoch']
print('Loading last_epoch')
if getattr(self, 'model', None) and 'model' in state:
if dist.is_parallel(self.model):
self.model.module.load_state_dict(state['model'])
else:
self.model.load_state_dict(state['model'])
print('Loading model.state_dict')
if getattr(self, 'ema', None) and 'ema' in state:
self.ema.load_state_dict(state['ema'])
print('Loading ema.state_dict')
if getattr(self, 'optimizer', None) and 'optimizer' in state:
self.optimizer.load_state_dict(state['optimizer'])
print('Loading optimizer.state_dict')
if getattr(self, 'lr_scheduler', None) and 'lr_scheduler' in state:
self.lr_scheduler.load_state_dict(state['lr_scheduler'])
print('Loading lr_scheduler.state_dict')
if getattr(self, 'scaler', None) and 'scaler' in state:
self.scaler.load_state_dict(state['scaler'])
print('Loading scaler.state_dict')
def save(self, path):
'''save state
'''
state = self.state_dict()
dist.save_on_master(state, path)
def resume(self, path):
'''load resume
'''
# for cuda:0 memory
state = torch.load(path, map_location='cpu')
self.load_state_dict(state)
def load_tuning_state(self, path,):
"""only load model for tuning and skip missed/dismatched keys
"""
if 'http' in path:
state = torch.hub.load_state_dict_from_url(path, map_location='cpu')
else:
state = torch.load(path, map_location='cpu')
module = dist.de_parallel(self.model)
# TODO hard code
if 'ema' in state:
stat, infos = self._matched_state(module.state_dict(), state['ema']['module'])
else:
stat, infos = self._matched_state(module.state_dict(), state['model'])
module.load_state_dict(stat, strict=False)
print(f'Load model.state_dict, {infos}')
@staticmethod
def _matched_state(state: Dict[str, torch.Tensor], params: Dict[str, torch.Tensor]):
missed_list = []
unmatched_list = []
matched_state = {}
for k, v in state.items():
if k in params:
if v.shape == params[k].shape:
matched_state[k] = params[k]
else:
unmatched_list.append(k)
else:
missed_list.append(k)
return matched_state, {'missed': missed_list, 'unmatched': unmatched_list}
def fit(self, ):
raise NotImplementedError('')
def val(self, ):
raise NotImplementedError('')
보게 되면 에러를 raise하는 것을 알 수 있다. 찾아보니, BaseSolver를 상속한 클래스에서 test를 수행한다.
'''
by lyuwenyu
'''
import time
import json
import datetime
import torch
from src.misc import dist
from src.data import get_coco_api_from_dataset
from .solver import BaseSolver
from .det_engine import train_one_epoch, evaluate
class DetSolver(BaseSolver):
def fit(self, ):
print("Start training")
self.train()
args = self.cfg
n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
# best_stat = {'coco_eval_bbox': 0, 'coco_eval_masks': 0, 'epoch': -1, }
best_stat = {'epoch': -1, }
start_time = time.time()
for epoch in range(self.last_epoch + 1, args.epoches):
if dist.is_dist_available_and_initialized():
self.train_dataloader.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
self.model, self.criterion, self.train_dataloader, self.optimizer, self.device, epoch,
args.clip_max_norm, print_freq=args.log_step, ema=self.ema, scaler=self.scaler)
self.lr_scheduler.step()
if self.output_dir:
checkpoint_paths = [self.output_dir / 'checkpoint.pth']
# extra checkpoint before LR drop and every 100 epochs
if (epoch + 1) % args.checkpoint_step == 0:
checkpoint_paths.append(self.output_dir / f'checkpoint{epoch:04}.pth')
for checkpoint_path in checkpoint_paths:
dist.save_on_master(self.state_dict(epoch), checkpoint_path)
module = self.ema.module if self.ema else self.model
test_stats, coco_evaluator = evaluate(
module, self.criterion, self.postprocessor, self.val_dataloader, base_ds, self.device, self.output_dir
)
# TODO
for k in test_stats.keys():
if k in best_stat:
best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch']
best_stat[k] = max(best_stat[k], test_stats[k][0])
else:
best_stat['epoch'] = epoch
best_stat[k] = test_stats[k][0]
print('best_stat: ', best_stat)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if self.output_dir and dist.is_main_process():
with (self.output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
# for evaluation logs
if coco_evaluator is not None:
(self.output_dir / 'eval').mkdir(exist_ok=True)
if "bbox" in coco_evaluator.coco_eval:
filenames = ['latest.pth']
if epoch % 50 == 0:
filenames.append(f'{epoch:03}.pth')
for name in filenames:
torch.save(coco_evaluator.coco_eval["bbox"].eval,
self.output_dir / "eval" / name)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
def val(self, ):
self.eval()
base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
module = self.ema.module if self.ema else self.model
test_stats, coco_evaluator = evaluate(module, self.criterion, self.postprocessor,
self.val_dataloader, base_ds, self.device, self.output_dir)
if self.output_dir:
dist.save_on_master(coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth")
return
제공된 코드에서 DetSolver 클래스의 val() 메서드가 모델의 검증(테스트) 과정을 정의한 부분입니다. DetSolver는 BaseSolver를 상속한 클래스이며, val() 메서드를 재정의하여 테스트를 수행합니다.
val() 메서드의 동작 방식
이 메서드는 모델의 검증(테스트)을 다음 단계로 수행합니다:
초기화 (self.eval() 호출):
self.eval() 메서드는 BaseSolver에서 정의된 메서드로, 검증을 위해 필요한 초기화 작업을 수행합니다. 이 작업에는 데이터 로더 초기화 및 체크포인트 로드 등이 포함됩니다.
데이터셋 API 가져오기:
base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)는 검증에 사용할 COCO API를 가져옵니다. COCO 데이터셋을 기반으로 검증을 수행할 때 유용합니다.
모델 선택:
module = self.ema.module if self.ema else self.model는 만약 EMA(지수 이동 평균) 모델이 사용되고 있다면, 이를 사용하도록 설정합니다. 그렇지 않으면 기본 모델을 사용합니다.
평가 수행:
test_stats, coco_evaluator = evaluate(module, self.criterion, self.postprocessor, self.val_dataloader, base_ds, self.device, self.output_dir)에서 실제로 모델의 평가가 수행됩니다.
evaluate 함수는 모델이 데이터셋을 기반으로 예측을 수행하고, 그 결과를 평가합니다.
이 과정에서 test_stats에는 테스트 결과가, coco_evaluator에는 COCO 평가 지표가 담깁니다.
평가 결과 저장:
dist.save_on_master(coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth")는 평가 결과를 파일로 저장합니다. 이 결과는 COCO 형식의 평가 지표로 저장됩니다.
종료:
return 문으로 검증이 종료됩니다.
요약
val() 메서드는 모델을 평가(테스트)하기 위해 필요한 모든 초기화를 수행한 후, 데이터셋을 통해 모델의 성능을 평가하고 그 결과를 저장하는 과정을 수행합니다. 이 메서드는 주로 COCO 데이터셋을 기반으로 하는 객체 탐지 모델에서 사용됩니다.
따라서, test 결과도 차량만 하기 위해서는 coco dataset api가 아닌 다른 api로 실시해야한다. api만 바꿔서 그 api를 가져오는 코드가 원활히 작동한다는 보장은 없긴 하지만