# -*- coding: utf-8 -*-
'''
Train CIFAR10 with PyTorch and Vision Transformers!
written by @kentaroy47, @arutema47
'''
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
import torchvision.transforms as transforms
import os
import argparse
import pandas as pd
import csv
from models import *
from models.vit import ViT, channel_selection
from utils import progress_bar
# Pin training to the first visible GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# ---- Command-line arguments -------------------------------------------------
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')  # resnets.. 1e-3, Vit..1e-4?
parser.add_argument('--opt', default="adam")
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
parser.add_argument('--aug', action='store_true', help='add image augumentations')
parser.add_argument('--mixup', action='store_true', help='add mixup augumentations')
parser.add_argument('--net', default='vit')
parser.add_argument('--bs', default='64')
# NOTE: argparse applies `type` to string defaults, so default='100' parses to 100.
parser.add_argument('--n_epochs', type=int, default='100')
parser.add_argument('--patch', default='4', type=int)
parser.add_argument('--cos', action='store_true', help='Train with cosine annealing scheduling')
args = parser.parse_args()

# Optional third-party helpers are imported lazily so the script still runs
# when the corresponding flag is off and the package is not installed.
# (See the warmup_scheduler / albumentations project docs for details.)
if args.cos:
    from warmup_scheduler import GradualWarmupScheduler
if args.aug:
    import albumentations

bs = int(args.bs)  # --bs is declared without type=, so convert the string once here
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Training-state globals; both may be overwritten by --resume below.
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
# ---- Data -------------------------------------------------------------------
print('==> Preparing data..')

# Standard CIFAR-10 training augmentation (random padded crop + horizontal
# flip) followed by per-channel normalization with the usual CIFAR-10 stats.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Test set gets no augmentation — tensor conversion and normalization only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# NOTE(review): dataset root is a hard-coded machine-specific path — confirm
# it exists on the target host or make it a CLI argument.
trainset = torchvision.datasets.CIFAR10(root='/home/lxc/ABCPruner/data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=bs, shuffle=True, num_workers=8)

testset = torchvision.datasets.CIFAR10(root='/home/lxc/ABCPruner/data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)

# CIFAR-10 class names, indexed by label id.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# ---- Model ------------------------------------------------------------------
print('==> Building model..')
# Select the architecture from --net; all choices are CIFAR-10 classifiers.
if args.net == 'res18':
    net = ResNet18()
elif args.net == 'vgg':
    net = VGG('VGG19')
elif args.net == 'res34':
    net = ResNet34()
elif args.net == 'res50':
    net = ResNet50()
elif args.net == 'res101':
    net = ResNet101()
elif args.net == "vit":
    # ViT for cifar10: 32x32 images split into --patch sized patches.
    net = ViT(
        image_size=32,
        patch_size=args.patch,
        num_classes=10,
        dim=512,  # 512
        depth=6,
        heads=8,
        mlp_dim=512,
        dropout=0.1,
        emb_dropout=0.1,
    )
else:
    # Fail fast with a clear message instead of a NameError on `net` below.
    raise ValueError('Unknown --net choice: {}'.format(args.net))

# Move the model to GPU when available (see `device` computed above).
net = net.to(device)
# Multi-GPU training was deliberately disabled; re-enable if needed:
# if device == 'cuda':
#     net = torch.nn.DataParallel(net)  # make parallel
#     cudnn.benchmark = True
if args.resume:
    # Restore training state (weights, best accuracy, epoch counter) from the
    # checkpoint written by a previous run of this script for the same --net.
    print('==> Resuming from checkpoint..')
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    if not os.path.isdir('checkpoint'):
        raise RuntimeError('Error: no checkpoint directory found!')
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    checkpoint = torch.load('./checkpoint/{}-ckpt.t7'.format(args.net), map_location=device)
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']