import yaml # conda install PyYAML
import os
from train.trainer import Trainer
from util.dataloader import MyTrainSetWrapper
# Directory holding the per-model YAML config files: the "config/" folder
# joined onto the directory component of this file's path. The trailing "/"
# lets a file name be appended directly to this string.
CONFIG_PATH = os.path.join(
    os.path.dirname(__file__),
    "config/",
)
# os.path.dirname(__file__): locate config/ relative to this script's directory.
def main(model_name):
    """Load the YAML config for *model_name* and run training.

    Reads ``<CONFIG_PATH>/<model_name>.yaml``, builds the training set from
    the config's ``"train"`` section, and hands the training set, the model
    name and the full config to ``Trainer``, then calls its ``train()``.

    Args:
        model_name: Config name without the ``.yaml`` extension
            (e.g. ``"resnet_50"``); converted with ``str()`` to build the
            file name.

    Raises:
        FileNotFoundError: if the config file does not exist.
        ValueError: if the config file is empty (YAML parses it to ``None``).
    """
    # CONFIG_PATH already ends with "/", so join adds no extra separator.
    config_file = os.path.join(CONFIG_PATH, str(model_name) + ".yaml")

    # The context manager guarantees the handle is closed after parsing.
    # NOTE(review): FullLoader can construct some Python objects from YAML
    # tags. That is acceptable for a local config shipped with the code; use
    # yaml.safe_load instead if configs may ever come from untrusted sources.
    with open(config_file, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    # An empty YAML file yields None; fail with a clear message instead of a
    # TypeError on the subscript below.
    if config is None:
        raise ValueError(f"Config file is empty: {config_file}")

    trainset = MyTrainSetWrapper(**config["train"])
    downstream = Trainer(trainset, model_name, config)
    downstream.train()
if __name__ == "__main__":
    # Script entry point: train with the config/resnet_50.yaml settings.
    main("resnet_50")
# A model's config is read from <CONFIG_PATH>/<model_name>.yaml (the file must already exist; it is not created).