[GDSC/ML] : ImageNet Training Based on Pytorch Basic Template

YOOJIN·2022년 10월 13일
0

Train.py

import yaml  # conda install PyYAML
import os
from train.trainer import Trainer
from util.dataloader import MyTrainSetWrapper

# Folder holding the per-model YAML files, located next to this script.
CONFIG_PATH = os.path.join(
    os.path.dirname(__file__),
    "config/",
)

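os.path.dirname(__file__) : returns the directory that contains this script, so the config folder is located relative to the script itself rather than the current working directory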

def main(model_name):
    """Load the YAML config for ``model_name`` and run training.

    The config file is looked up as ``<model_name>.yaml`` inside
    ``CONFIG_PATH``. Its ``"train"`` section is expanded into keyword
    arguments for ``MyTrainSetWrapper``; the full config is handed to
    ``Trainer`` together with the dataset wrapper and the model name.

    Args:
        model_name: Base name of the config file (without ``.yaml``);
            also passed through to ``Trainer``.
    """
    config_file = CONFIG_PATH + str(model_name) + ".yaml"

    # Use a context manager so the file handle is closed as soon as the
    # YAML has been parsed (the bare open() call leaked the handle).
    # NOTE(review): FullLoader is kept for compatibility; if the configs
    # only contain plain YAML types, yaml.safe_load would be sufficient.
    with open(config_file, "r") as stream:
        config = yaml.load(stream, Loader=yaml.FullLoader)

    trainset = MyTrainSetWrapper(**config["train"])
    downstream = Trainer(trainset, model_name, config)
    downstream.train()

if __name__ == "__main__":
    # Script entry point: train with the "resnet_50" configuration.
    main(model_name="resnet_50")

CONFIG_PATH + str(model_name) + ".yaml" : builds the path to <model_name>.yaml inside CONFIG_PATH (e.g. config/resnet_50.yaml); the file is only read here, not generated

GitHub

0개의 댓글