[230105-2] Perception in self driving car

junyong lee·2023년 1월 6일
0

PyTorch LeNet-5 MNIST 학습


main.py

from turtle import down
import argparse
import sys, os
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data.dataloader import DataLoader
import torch.optim as optim

from model.models import *
from loss.loss import *
from util.tools import *

def _str2bool(value):
    """Parse a command-line boolean such as "1", "0", "true" or "False"."""
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of main.py.

    Prints the help text and exits when the script is run without any
    arguments.

    Returns:
        argparse.Namespace with ``mode``, ``download``, ``output_dir`` and
        ``checkpoint`` attributes.
    """
    parser = argparse.ArgumentParser(description="MNIST")
    parser.add_argument('--mode', dest='mode', help="train / eval / test",
                        choices=["train", "eval", "test"], default=None, type=str)
    # `type=bool` would turn every non-empty string (including "0" and
    # "False") into True, so parse the value explicitly.
    parser.add_argument('--download', dest='download', help="download MNIST dataset",
                        default=False, type=_str2bool)
    parser.add_argument('--output_dir', dest='output_dir', help="output directory",
                        default='./output', type=str)
    parser.add_argument('--checkpoint', dest='checkpoint', help="checkpoint trained model",
                        default=None, type=str)

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit()
    args = parser.parse_args()
    return args

def get_data(download=None, root="./mnist_dataset"):
    """Build the MNIST train / eval / test datasets.

    Images are resized to 32x32 (the input size LeNet-5 expects).
    ToTensor maps pixels to [0, 1], and Normalize((0.5,), (1.0,)) then shifts
    them to [-0.5, 0.5].

    Args:
        download: whether torchvision should download MNIST into *root*.
            When None, falls back to the ``--download`` flag stored in the
            module-level ``args`` parsed in the ``__main__`` block.
        root: directory where the dataset is stored.

    Returns:
        (train_dataset, eval_dataset, test_dataset). MNIST ships no separate
        validation split, so eval and test are the same official test split
        (one shared dataset object).
    """
    if download is None:
        download = args.download

    my_transform = transforms.Compose([
        transforms.Resize([32, 32]),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (1.0,))
    ])
    train_dataset = MNIST(root=root,
                          transform=my_transform,
                          train=True,
                          download=download)
    # Load the test split once and reuse it for both evaluation and testing.
    # NOTE(review): evaluating on the test split means there is no truly
    # held-out set; consider carving a validation split out of train.
    test_dataset = MNIST(root=root,
                         transform=my_transform,
                         train=False,
                         download=download)

    return train_dataset, test_dataset, test_dataset

def _load_checkpoint(model, checkpoint_path, device):
    """Load *checkpoint_path* into *model*, move it to *device* and set eval mode."""
    if checkpoint_path is None:
        sys.exit("--checkpoint is required for eval / test mode")
    # map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    model.eval()
    return model


def _train(model_cls, train_loader, device, output_dir, batch_size, num_epochs=15):
    """Train a fresh model and save a checkpoint after every epoch."""
    model = model_cls(batch=batch_size, n_classes=10, in_channel=1,
                      in_width=32, in_height=32, is_train=True)
    model.to(device)
    model.train()

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    # Multiply the learning rate by 0.1 every 5 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    criterion = get_criterion(crit='mnist', device=device)

    global_step = 0
    for e in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        for img, gt in train_loader:
            img = img.to(device)
            gt = gt.to(device)

            out = model(img)
            loss_val = criterion(out, gt)

            # backpropagation
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            total_loss += loss_val.item()
            num_batches += 1

            if global_step % 100 == 0:
                print(f"{e} epoch {global_step} iter loss : {loss_val.item()}")
            global_step += 1

        # Average over the number of batches actually processed this epoch.
        mean_loss = total_loss / max(num_batches, 1)
        scheduler.step()

        print(f"->{e} epoch mean loss : {mean_loss}")
        torch.save(model.state_dict(), os.path.join(output_dir, f"model_epoch{e}.pt"))
    print("Train end")


def _evaluate(model_cls, eval_loader, device, checkpoint_path, batch_size):
    """Load a checkpoint and print how many eval samples it classifies correctly."""
    model = model_cls(batch=batch_size, n_classes=10, in_channel=1, in_width=32, in_height=32)
    _load_checkpoint(model, checkpoint_path, device)

    acc = 0
    num_eval = 0
    # No gradients are needed for inference; this avoids building autograd graphs.
    with torch.no_grad():
        for img, gt in eval_loader:
            pred = model(img.to(device)).cpu()  # predicted class index per sample
            acc += int((pred == gt).sum().item())
            num_eval += gt.size(0)

    print(f"Evaluation Score : {acc} / {num_eval}")


def _test(model_cls, test_loader, device, checkpoint_path, batch_size):
    """Load a checkpoint, then print and display the prediction for each test image."""
    model = model_cls(batch=batch_size, n_classes=10, in_channel=1, in_width=32, in_height=32)
    _load_checkpoint(model, checkpoint_path, device)

    with torch.no_grad():
        for img, _ in test_loader:
            img = img.to(device)
            out = model(img).cpu()
            print(out)
            # show result (out.item() assumes one image per batch)
            show_img(img.cpu().numpy(), str(out.item()))


def main():
    """Entry point: train, evaluate or test LeNet-5 on MNIST depending on ``args.mode``.

    Reads the module-level ``args`` produced by ``parse_args()``.

    Usage:
        python main.py --mode train --download 1 --output_dir ./output
        python main.py --mode eval --download 1 --output_dir ./output \\
            --checkpoint ./output/model_epoch2.pt
        python main.py --mode test --download 1 --output_dir ./output \\
            --checkpoint ./output/model_epoch2.pt
    """
    print(torch.__version__)

    os.makedirs(args.output_dir, exist_ok=True)

    if torch.cuda.is_available():
        print("gpu")
        device = torch.device("cuda")
    else:
        print("cpu")
        device = torch.device("cpu")
    # Pinned host memory only speeds up host->GPU copies.
    use_pin_memory = device.type == "cuda"

    train_batch_size = 8
    infer_batch_size = 1

    # Get MNIST Dataset
    train_dataset, eval_dataset, test_dataset = get_data()

    # Make DataLoader
    train_loader = DataLoader(train_dataset,
                              batch_size=train_batch_size,
                              num_workers=0,
                              pin_memory=use_pin_memory,
                              drop_last=True,
                              shuffle=True)
    eval_loader = DataLoader(eval_dataset,
                             batch_size=infer_batch_size,
                             num_workers=0,
                             pin_memory=use_pin_memory,
                             drop_last=False,
                             shuffle=False)
    test_loader = DataLoader(test_dataset,
                             batch_size=infer_batch_size,
                             num_workers=0,
                             pin_memory=use_pin_memory,
                             drop_last=False,
                             shuffle=False)

    # LeNet5 (get_model returns the class; each mode builds its own instance)
    _model = get_model('lenet5')

    if args.mode == "train":
        _train(_model, train_loader, device, args.output_dir, train_batch_size)
    elif args.mode == "eval":
        _evaluate(_model, eval_loader, device, args.checkpoint, infer_batch_size)
    elif args.mode == "test":
        _test(_model, test_loader, device, args.checkpoint, infer_batch_size)
    else:
        print(f"unknown mode: {args.mode}")
            
if __name__ == "__main__":
    # The parsed options are bound to the module-level name `args`,
    # which get_data() and main() read as a global.
    args = parse_args()
    main()

# Image-classification workflow, step by step:
# 1. Get the dataset
# 2. Build the DataLoaders (the data pipeline used for training)
# 3. Design the model
# 4. Train
# 5. Set up the optimizer & learning-rate scheduler
# 6. Define the loss function
# 7. Forward pass -> loss_val
# 8. loss_val -> backpropagation -> optimizer.step(), optimizer.zero_grad()
# 9. Save the model

model

lenet5.py

import torch
import torch.nn as nn

class Lenet5(nn.Module):
    """LeNet-5 style convolutional classifier.

    The layer sizes are fixed for a 32x32 input:
    conv0 (5x5) -> 28x28, pool0 -> 14x14, conv1 (5x5) -> 10x10,
    pool1 -> 5x5, conv2 (5x5) -> 1x1 with 120 channels, which is exactly
    the 120 features fc3 expects after flattening.

    Args:
        batch: batch size the model was configured with. Stored for backward
            compatibility; forward() does not depend on it, so any batch
            size works.
        n_classes: number of output classes.
        in_channel: number of input channels.
        in_width, in_height: input spatial size (stored only; the layers are
            sized for 32x32).
        is_train: if True, forward() returns raw logits of shape
            [B, n_classes], suitable for nn.CrossEntropyLoss. Otherwise it
            returns the predicted class index per sample, shape [B].
    """

    def __init__(self, batch, n_classes, in_channel, in_width, in_height, is_train=False):
        super().__init__()
        self.batch = batch
        self.n_classes = n_classes
        self.in_width = in_width
        self.in_height = in_height
        self.in_channel = in_channel
        self.is_train = is_train

        # convolution output size: [(W - K + 2P) / S] + 1
        # e.g. conv0 on 32x32: [(32 - 5 + 2*0) / 1] + 1 = 28
        self.conv0 = nn.Conv2d(self.in_channel, 6, kernel_size=5, stride=1, padding=0)
        self.pool0 = nn.AvgPool2d(2, stride=2)
        self.conv1 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0)
        self.pool1 = nn.AvgPool2d(2, stride=2)
        self.conv2 = nn.Conv2d(16, 120, kernel_size=5, stride=1, padding=0)

        # fully-connected layers
        self.fc3 = nn.Linear(120, 84)
        self.fc4 = nn.Linear(84, self.n_classes)

    def forward(self, x):
        # x's shape: [B, C, H, W]
        x = self.pool0(torch.tanh(self.conv0(x)))
        x = self.pool1(torch.tanh(self.conv1(x)))
        x = torch.tanh(self.conv2(x))
        # 4-D -> 2-D ([B, C, H, W] -> [B, C*H*W]); uses the real batch size
        # instead of the configured one, so partial batches also work.
        x = torch.flatten(x, start_dim=1)
        x = torch.tanh(self.fc3(x))
        logits = self.fc4(x)  # [B, n_classes]

        if self.is_train:
            # CrossEntropyLoss applies log-softmax itself; feeding it softmax
            # probabilities would apply softmax twice.
            return logits
        # softmax is monotonic, so the argmax over logits is the same prediction.
        return torch.argmax(logits, dim=1)
        

models.py

from model.lenet5 import Lenet5

def get_model(model_name):
    """Return the model class registered under *model_name*.

    The class itself (not an instance) is returned so the caller can build it
    with its own hyper-parameters.

    Raises:
        ValueError: if *model_name* is not a known model.
    """
    if model_name == "lenet5":
        return Lenet5
    # Fail here instead of returning None, which would only surface later as
    # "'NoneType' object is not callable" at the construction site.
    raise ValueError(f"unknown model: {model_name}")

loss

loss.py

import torch
import torch.nn as nn
import sys

class MNISTloss(nn.Module):
    """Cross-entropy criterion for MNIST digit classification.

    A thin wrapper around nn.CrossEntropyLoss. It presumably receives
    [B, n_classes] scores as `out` and integer class indices [B] as `gt`
    (the inputs CrossEntropyLoss accepts for class-index targets).
    """

    def __init__(self, device=torch.device('cpu')):
        super().__init__()
        # CrossEntropyLoss has no learnable parameters; .to(device) is kept as-is.
        self.loss = nn.CrossEntropyLoss().to(device)

    def forward(self, out, gt):
        return self.loss(out, gt)
    
def get_criterion(crit="mnist", device=torch.device('cpu')):
    """Return the loss module registered under *crit*.

    Args:
        crit: criterion name; only "mnist" is supported.
        device: device passed through to the loss module.

    Exits the process with status 1 (SystemExit) when *crit* is unknown.
    """
    if crit == "mnist":
        return MNISTloss(device=device)
    print(f"unknown criterion: {crit}", file=sys.stderr)
    sys.exit(1)

util

tools.py

from PIL import Image, ImageDraw
import numpy as np
import matplotlib.pyplot as plt

def _to_uint8(plane):
    """Min-max rescale a 2-D array to the full 0-255 range and return it as uint8."""
    plane = np.asarray(plane, dtype=np.float64)
    lo, hi = float(plane.min()), float(plane.max())
    if hi > lo:
        scaled = (plane - lo) / (hi - lo) * 255.0
    else:
        # Constant image: nothing to stretch, use the plain [0, 1] -> [0, 255] mapping.
        scaled = plane * 255.0
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)


def show_img(img_data, text):
    """Display one grayscale image with *text* drawn at its centre.

    Args:
        img_data: array indexed as [batch, channel, H, W]; only image 0,
            channel 0 is shown.
        text: label drawn at the image centre; skipped when None.
    """
    # Rescale instead of a plain `* 255` cast: the input may be mean-shifted
    # (e.g. contain negative values), which a raw uint8 cast would wrap around.
    gray = _to_uint8(np.asarray(img_data)[0, 0])  # 4-D -> 2-D

    img = Image.fromarray(gray)
    draw = ImageDraw.Draw(img)

    # PIL positions are (x, y) = (column, row); numpy shape is (rows, cols).
    cx, cy = gray.shape[1] // 2, gray.shape[0] // 2

    # draw text in image
    if text is not None:
        draw.text((cx, cy), text)

    plt.imshow(img, cmap="gray")
    plt.show()

https://github.com/Jun-yong-lee/pytorch_study/tree/pytorch_MNIST


profile
도전하는 개발자 지망생

0개의 댓글