2022.7.7 농산물 품질분류 프로젝트 3일차

정성우·2022년 7월 7일
0

학습한 내용

import os
import torch
import torch.nn as nn

"""모델 가중치 저장할 폴더 생성"""
os.makedirs("./weights", exist_ok=True)

# 1. Device: prefer CUDA when it is available, otherwise run on the CPU.
#    (Apple-silicon users can use torch.device("mps") instead.)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hyperparameter settings.
batch_size = 20
num_epochs = 10
val_every = 1                   # run validation every `val_every` epochs
save_weights_dir = "./weights"  # checkpoint output directory
data_path = "./data/"           # dataset root: <root>/<split>/image/<class>/*.png
nc = 54                         # number of classes
lr = 0.025                      # learning rate passed to the optimizer
criterion = nn.CrossEntropyLoss().to(device)
import glob
import os
from PIL import Image
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Image dataset for the produce-quality classification project.

    Images are collected from ``<data_path>/<mode>/image/<class_folder>/*.png``.
    The name of each image's parent folder (produce + variety + grade, e.g.
    ``applefujiL``) determines its integer label, following ``CLASS_NAMES``.
    """

    # Produce+variety prefixes in label order. Each is combined with the
    # grades L, M, S (in that order), giving 18 * 3 = 54 class folders.
    _PRODUCE_VARIETIES = (
        "applefuji", "appleyanggwang", "cabbagegreen", "cabbagered",
        "chinesecabbage", "garlicuiseong", "mandarinehallabong",
        "mandarineonjumilgam", "onionred", "onionwhite", "pearchuhwang",
        "pearsingo", "persimmonbansi", "persimmonbooyu", "persimmondaebong",
        "potatoseolbong", "potatosumi", "radishwinterradish",
    )

    # Class-folder names indexed by label: "applefujiL" -> 0 ... "radishwinterradishS" -> 53.
    # The grade tuple is written inline because class attributes are not
    # visible inside the nested loop of a class-body generator expression.
    CLASS_NAMES = tuple(
        variety + grade
        for variety in _PRODUCE_VARIETIES
        for grade in ("L", "M", "S")
    )

    # Reverse lookup: folder name -> label.
    _LABELS = {name: idx for idx, name in enumerate(CLASS_NAMES)}

    def __init__(self, data_path, mode, transform=None):
        """Collect every image path for one split.

        Args:
            data_path: dataset root directory.
            mode: split sub-folder name (e.g. "train" or "val").
            transform: optional callable applied to each loaded PIL image.
        """
        # sorted() makes the sample order deterministic across filesystems.
        self.all_data = sorted(
            glob.glob(os.path.join(data_path, mode, "image", "*", "*.png")))
        self.transform = transform

    @classmethod
    def _label_from_path(cls, path):
        """Return the integer label encoded by the parent folder of *path*.

        Raises:
            ValueError: if the folder name is not one of ``CLASS_NAMES``.
        """
        # os.path understands the platform's separators; the previous
        # split("\\")[2] only worked on Windows and for one exact path depth
        # (on POSIX it silently labelled every sample 0).
        folder = os.path.basename(os.path.dirname(path))
        try:
            return cls._LABELS[folder]
        except KeyError:
            # Previously unknown folders fell through to label 0, which
            # silently mislabelled data; fail loudly instead.
            raise ValueError(
                "Unknown class folder {!r} for image {!r}".format(folder, path)
            ) from None

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at *index*.

        The image is converted to RGB and passed through ``transform`` if set.
        """
        data_path = self.all_data[index]
        label = self._label_from_path(data_path)

        # The context manager closes the underlying file; convert() returns
        # an independent in-memory copy.
        with Image.open(data_path) as img:
            images = img.convert("RGB")

        if self.transform is not None:
            images = self.transform(images)

        return images, label

    def __len__(self):
        """Return the number of images found for this split."""
        return len(self.all_data)
#모델 정의
import torch.nn as nn
import torchvision.models as models
import torch
import configer

# Re-exported for modules that read the device from here; initialize_model
# itself does not use it.
device = configer.device

# Model-building code adapted from the PyTorch fine-tuning tutorial (Korean edition):
# https://tutorials.pytorch.kr/beginner/finetuning_torchvision_models_tutorial.html


def initialize_model(model_name, num_classes, use_pretrained=True):
    """Build a torchvision classifier whose final layer outputs *num_classes*.

    Adapted from the PyTorch fine-tuning tutorial. Only the classification
    head is replaced; no parameters are frozen.

    Args:
        model_name: one of "resnet", "alexnet", "vgg", "squeezenet",
            "densenet" or "inception".
        num_classes: number of output classes for the new head.
        use_pretrained: passed to torchvision as ``pretrained=``.

    Returns:
        ``(model, input_size)``: the model and the square input resolution it
        expects (299 for Inception v3, 224 for the others).

    Raises:
        ValueError: if *model_name* is not supported.
    """
    model_ft = None
    input_size = 0

    if model_name == "resnet":
        # ResNet-18: replace the final fully connected layer.
        model_ft = models.resnet18(pretrained=use_pretrained)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "alexnet":
        # AlexNet: the last classifier entry is the output layer.
        model_ft = models.alexnet(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG11 with batch norm: same classifier layout as AlexNet.
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # SqueezeNet 1.0 classifies with a 1x1 convolution, not a Linear layer.
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        model_ft.classifier[1] = nn.Conv2d(
            512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # DenseNet-121: a single Linear classifier.
        model_ft = models.densenet121(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3 expects 299x299 inputs and has an auxiliary output head,
        # so both the auxiliary and the primary classifiers are replaced.
        model_ft = models.inception_v3(pretrained=use_pretrained)
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        # Raise instead of print()+exit(): library code should not terminate
        # the interpreter, and callers get a clear, catchable error.
        raise ValueError("Invalid model name: {!r}".format(model_name))

    return model_ft, input_size
# data augmentation, 모델학습 및 결과 계산과 best 모델 저장
from numpy import argmax
import torchvision.transforms as transforms
import torch
import os
from tqdm import tqdm

"""1. aug 2. train loop 3. val loop 4. save model 5. eval"""


def data_augmentation():
    """Return the image preprocessing pipelines keyed by split.

    Returns:
        dict with ``'train'`` (resize to 224x224 plus random horizontal and
        vertical flips) and ``'test'`` (resize only). Both pipelines finish
        with ``ToTensor`` and the same per-channel ``Normalize``.
    """
    def tensor_tail():
        # Shared ending: PIL image -> float tensor -> per-channel normalization.
        return [
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.2, 0.2, 0.2]),
        ]

    train_steps = [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.4),
        transforms.RandomVerticalFlip(p=0.2),
    ]
    test_steps = [transforms.Resize((224, 224))]

    return {
        'train': transforms.Compose(train_steps + tensor_tail()),
        'test': transforms.Compose(test_steps + tensor_tail()),
    }


"""train loop"""


def train(num_epoch, model, train_loader, test_loader, criterion, optimizer,
          save_dir, val_every, device, lr_scheduler=None):
    """Train *model*, validating periodically and checkpointing the best weights.

    After every ``val_every``-th epoch the model is evaluated on
    ``test_loader``. Whenever the average validation loss improves, the weights
    are saved as ``best.pt`` in *save_dir*. The final weights are always saved
    as ``last.pt``.

    Args:
        num_epoch: number of passes over ``train_loader``.
        model: network to train. This function moves batches to *device* but
            not the model, so the caller must move the model there beforehand.
        train_loader: iterable of ``(images, labels)`` training batches.
        test_loader: iterable of ``(images, labels)`` validation batches.
        criterion: loss function ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        save_dir: directory for ``best.pt`` / ``last.pt``.
        val_every: validate every this many epochs.
        device: device the batches are moved to.
        lr_scheduler: optional scheduler; ``step()`` is called once at the
            end of every epoch.
    """
    print("Starting... train !!! ")
    model.train()
    best_loss = float("inf")
    num_steps = len(train_loader)

    for epoch in range(num_epoch):
        for i, (imgs, labels) in enumerate(train_loader):
            imgs, labels = imgs.to(device), labels.to(device)
            output = model(imgs)
            loss = criterion(output, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Batch accuracy from the highest-scoring class per sample.
            _, preds = torch.max(output, 1)
            acc = (labels == preds).float().mean()

            print("Epoch [{}/{}], Step [{}/{}], Loss : {:.4f}, Acc : {:.2f}%".format(
                epoch + 1, num_epoch, i + 1, num_steps, loss.item(), acc.item() * 100
            ))

        # Validation belongs at epoch level. It used to sit inside the batch
        # loop and ran a full validation pass after every training step.
        if (epoch + 1) % val_every == 0:
            avg_loss = validation(
                epoch + 1, model, test_loader, criterion, device)
            if avg_loss < best_loss:
                print("Best prediction at epoch : {} ".format(epoch + 1))
                print("Save model in", save_dir)
                best_loss = avg_loss
                save_model(model, save_dir)

        if lr_scheduler is not None:
            lr_scheduler.step()

    save_model(model, save_dir, file_name="last.pt")


def validation(epoch, model, test_loader, criterion, device):
    """Evaluate *model* on *test_loader*, print accuracy, and return the average loss.

    The model is switched to eval mode for the pass. Afterwards it is restored
    to the mode it was in before the call.

    Args:
        epoch: epoch number, used only in log messages.
        model: network to evaluate.
        test_loader: iterable of ``(images, labels)`` batches.
        criterion: loss function ``criterion(outputs, labels)``.
        device: device the batches are moved to.

    Returns:
        The mean of the per-batch losses as a Python float. Returns
        ``float("inf")`` when ``test_loader`` yields no batches (e.g. a split
        smaller than the batch size with ``drop_last=True``), so that it never
        counts as an improvement.
    """
    print("Start validation # {}".format(epoch))
    was_training = model.training
    model.eval()

    total = 0
    correct = 0
    total_loss = 0.0
    cnt = 0
    try:
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                loss = criterion(outputs, labels)

                total += imgs.size(0)
                _, preds = torch.max(outputs, 1)
                correct += (labels == preds).sum().item()
                # .item() accumulates a plain float rather than device tensors.
                total_loss += loss.item()
                cnt += 1
    finally:
        # Restore the caller's mode even if evaluation raises.
        model.train(was_training)

    if cnt == 0:
        print("Validation # {} skipped: test_loader yielded no batches".format(epoch))
        return float("inf")

    avg_loss = total_loss / cnt
    print("Validation # {} Acc : {:.2f}% Average Loss : {:.4f}".format(
        epoch, correct / total * 100, avg_loss
    ))
    return avg_loss

  
def save_model(model, save_dir, file_name="best.pt"):
    """Save *model*'s ``state_dict`` to ``save_dir/file_name``.

    *save_dir* is created first if it does not exist, so callers need not
    pre-create it.

    Returns:
        The path the weights were written to.
    """
    os.makedirs(save_dir, exist_ok=True)
    output_path = os.path.join(save_dir, file_name)
    torch.save(model.state_dict(), output_path)
    return output_path


def eval(model, test_loader, device):  # NOTE: shadows the builtin eval(); name kept for callers.
    """Report *model*'s classification accuracy on *test_loader*.

    Prints the number of evaluated images and the accuracy. The model is
    evaluated in eval mode and then restored to its previous mode.

    Returns:
        Accuracy in percent, or ``None`` when the loader yields no images.
    """
    print("Starting evaluation")
    was_training = model.training
    model.eval()
    total = 0
    correct = 0

    try:
        with torch.no_grad():
            # Iterate the loader itself (not enumerate()) so tqdm can use its
            # length to show progress.
            for imgs, labels in tqdm(test_loader):
                imgs, labels = imgs.to(device), labels.to(device)

                outputs = model(imgs)
                # Pick the highest-scoring class for each image.
                _, preds = torch.max(outputs, 1)
                total += imgs.size(0)
                correct += (labels == preds).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        # Avoid ZeroDivisionError on an empty test set.
        print("Test acc for image : 0 ACC : n/a (no test images)")
        print("End test.. ")
        return None

    acc = correct / total * 100
    print("Test acc for image : {} ACC : {:.2f}".format(total, acc))
    print("End test.. ")
    return acc
#실행 파일
import torch.cuda
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from zmq import device
import utils_file
import models_build
from customdataset import CustomDataset
import configer
import torch
"""
models.py -> 학습할 모델 build 파일 
utils_file.py -> 여려가지 잡동사니 ex) Image show 필요한 함수 구현 곳 
customdataset.py -> 학습데이터를 가져오기위한 데이터셋 구성 
config.py -> 하이퍼파라메타 값 세팅 하는곳 
"""

"""1. augmentation setting"""
# Modules used here:
#   models_build.py  -> builds the model to train
#   utils_file.py    -> helpers: augmentation, train/validation loops, checkpointing
#   customdataset.py -> Dataset that loads the training / validation images
#   configer.py      -> hyperparameter settings
# NOTE(review): `from zmq import device` in the imports looks like an accidental
# IDE auto-import. configer.device is what is used below, and it adds a pyzmq
# dependency. Consider removing it.

# 1. Augmentation / preprocessing pipelines.
data_transform = utils_file.data_augmentation()

# 2. Datasets: (data path, split, transform).
train_data = CustomDataset(data_path=configer.data_path,
                           mode="train", transform=data_transform['train'])
test_data = CustomDataset(data_path=configer.data_path,
                          mode="val", transform=data_transform['test'])

# 3. Data loaders. Training may drop the last incomplete batch. Validation must
#    not: dropping it silently skips samples, and a split smaller than
#    batch_size would produce no batches at all.
train_loader = DataLoader(
    train_data, batch_size=configer.batch_size, shuffle=True, drop_last=True)
test_loader = DataLoader(
    test_data, batch_size=configer.batch_size, shuffle=False, drop_last=False)

# 4. Model. The transforms resize to 224x224, which matches the input size
#    initialize_model reports for "resnet".
net, image_size = models_build.initialize_model(
    "resnet", num_classes=configer.nc)
# The train loop moves every batch to configer.device, so the model has to
# live there too; otherwise CUDA runs fail with a device-mismatch error.
net = net.to(configer.device)

# 5. Loss function, optimizer and learning-rate scheduler.
criterion = configer.criterion
optimizer = optim.SGD(net.parameters(), lr=configer.lr, momentum=0.9)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer, step_size=4, gamma=0.1)
# NOTE(review): lr_scheduler is not passed to utils_file.train below, so
# nothing calls its step() and the learning rate stays constant.

# 6. Run the training loop.
utils_file.train(configer.num_epochs, net, train_loader, test_loader, criterion,
                 optimizer, configer.save_weights_dir, configer.val_every,
                 configer.device)

실행결과

학습한 내용 중 어려웠던 점 또는 해결못한 것들
CPU로 학습하니까 시간이 너무 오래 걸림

해결방법 작성

학습 소감
하이퍼파라미터 설정과 학습 모델 선택을 어떻게 해야 할지 잘 모르겠음

0개의 댓글