from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import ssl
# NOTE(review): this disables HTTPS certificate verification for standard-library
# HTTPS connections (urllib / http.client) in the whole process, presumably to work
# around a local certificate problem when downloading MNIST. Prefer fixing the
# local CA bundle and removing this line.
ssl._create_default_https_context = ssl._create_unverified_context

# MNIST images converted to float tensors by ToTensor().
train_set = datasets.MNIST(root='MNIST_data/',
                           train=True,
                           download=True,
                           transform=transforms.Compose([transforms.ToTensor()]))
# train=False selects the held-out test split; train=True would silently reuse
# the training images for "testing".
test_set = datasets.MNIST(root='MNIST_data/',
                          train=False,
                          download=True,
                          transform=transforms.Compose([transforms.ToTensor()]))

batch_size = 16
train_loader = DataLoader(train_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)
# Inspect one mini-batch: the first batch is pulled and thrown away, so the
# second batch is the one that gets printed and plotted.
dataiter = iter(train_loader)
next(dataiter)
images, labels = next(dataiter)
for value in (images.shape, labels, labels.shape):
    print(value)

from torchvision import utils
import numpy as np

# make_grid tiles the whole batch into a single channels-first image tensor;
# matplotlib wants channels-last, hence the (1, 2, 0) transpose.
img = utils.make_grid(images)
npimg = img.numpy()
hwc_img = np.transpose(npimg, (1, 2, 0))
print(npimg.shape)
print(hwc_img.shape)

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 7))
plt.imshow(hwc_img)
print(labels)
plt.show()
torch.Size([16, 1, 28, 28])
tensor([2, 8, 6, 9, 4, 0, 9, 1, 1, 2, 4, 3, 2, 7, 3, 8])
torch.Size([16])
(3, 62, 242)
(62, 242, 3)
tensor([2, 8, 6, 9, 4, 0, 9, 1, 1, 2, 4, 3, 2, 7, 3, 8])

추가 코드 (additional code)
# Show a single sample of the batch in grayscale, titled with its label.
sample_idx = 1
one_img = images[sample_idx]
print(one_img.shape)
# squeeze() drops the singleton channel dimension so imshow receives a 2-D array.
one_npimg = one_img.squeeze().numpy()
plt.title(f'"{labels[sample_idx]} " image')
plt.imshow(one_npimg, cmap='gray')
plt.show()

from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import ssl
import torch
# NOTE(review): this disables HTTPS certificate verification for standard-library
# HTTPS connections (urllib / http.client) in the whole process, presumably to work
# around a local certificate problem when downloading MNIST. Prefer fixing the
# local CA bundle and removing this line.
ssl._create_default_https_context = ssl._create_unverified_context

train_set = datasets.MNIST(root='MNIST_data/',
                           train=True,
                           download=True,
                           transform=transforms.Compose([transforms.ToTensor()]))
# train=False selects the held-out test split; with train=True the reported
# "test" accuracy would be measured on the training images.
test_set = datasets.MNIST(root='MNIST_data/',
                          train=False,
                          download=True,
                          transform=transforms.Compose([transforms.ToTensor()]))

batch_size = 100
# Shuffle every epoch; drop_last keeps every training batch the same size.
train_loader = DataLoader(train_set,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True)
import torch.nn as nn

# Multi-layer perceptron over flattened 28x28 images: 784 -> 256 -> 128 -> 10.
# The final layer has no activation: it emits raw logits for the
# classification loss applied later.
model = nn.Sequential(
    nn.Linear(in_features=28 * 28, out_features=256),
    nn.ReLU(),
    nn.Linear(in_features=256, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=10),
)
import torch.optim as optim

loss_func = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

for epoch in range(20):
    avg_loss = 0.0
    total_batch = len(train_loader)
    for x_train, y_train in train_loader:
        # Flatten each image to a 784-vector for the fully connected model.
        x_train = x_train.view(-1, 28 * 28)
        optimizer.zero_grad()
        hypothesis = model(x_train)
        loss = loss_func(hypothesis, y_train)
        loss.backward()
        optimizer.step()
        # .item() turns the loss into a plain float. Accumulating the tensor
        # itself would chain every batch's autograd history into avg_loss for
        # the whole epoch.
        avg_loss += loss.item() / total_batch
    print(f'epoch:{epoch+1}, avg_loss:{avg_loss:.4f}')
import random
import matplotlib.pyplot as plt

with torch.no_grad():
    # `.data` / `.targets` replace the deprecated `test_data` / `test_labels`
    # (torchvision emits a UserWarning for the old names). The raw pixel values
    # are 0-255, but training went through ToTensor(), which scales to [0, 1].
    # Divide by 255 so evaluation inputs match the training distribution.
    x_test = test_set.data.view(-1, 28 * 28).float() / 255.0
    y_test = test_set.targets
    prediction = model(x_test)
    correct_prediction = torch.argmax(prediction, dim=1) == y_test
    accuracy = correct_prediction.float().mean()
    print('accuracy : {:2.2f}%'.format(accuracy * 100))
    print()

    # Predict and display one randomly chosen test sample.
    r = random.randint(0, len(test_set) - 1)
    x_single_data = test_set.data[r:r + 1].view(-1, 28 * 28).float() / 255.0
    y_single_data = test_set.targets[r:r + 1]
    print('target label:', y_single_data.item())
    s_prediction = model(x_single_data)
    print('model prediction:', torch.argmax(s_prediction, dim=1).item())
    plt.imshow(test_set.data[r].view(28, 28), cmap='gray')
    plt.show()
epoch:1, avg_loss:0.6124
epoch:2, avg_loss:0.2336
epoch:3, avg_loss:0.1637
epoch:4, avg_loss:0.1263
epoch:5, avg_loss:0.1017
epoch:6, avg_loss:0.0843
epoch:7, avg_loss:0.0713
epoch:8, avg_loss:0.0610
epoch:9, avg_loss:0.0525
epoch:10, avg_loss:0.0449
epoch:11, avg_loss:0.0392
epoch:12, avg_loss:0.0340
epoch:13, avg_loss:0.0296
epoch:14, avg_loss:0.0258
epoch:15, avg_loss:0.0221
epoch:16, avg_loss:0.0192
epoch:17, avg_loss:0.0163
epoch:18, avg_loss:0.0140
epoch:19, avg_loss:0.0120
epoch:20, avg_loss:0.0105
C:\Users\hi\Desktop\PS\python_lib\lib\site-packages\torchvision\datasets\mnist.py:81: UserWarning: test_data has been renamed data
warnings.warn("test_data has been renamed data")
C:\Users\hi\Desktop\PS\python_lib\lib\site-packages\torchvision\datasets\mnist.py:71: UserWarning: test_labels has been renamed targets
warnings.warn("test_labels has been renamed targets")
accuracy : 99.64%
target label: 4
model prediction: 4
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import ssl
import torch
# NOTE(review): this disables HTTPS certificate verification for standard-library
# HTTPS connections (urllib / http.client) in the whole process, presumably to work
# around a local certificate problem when downloading the dataset. Prefer fixing
# the local CA bundle and removing this line.
ssl._create_default_https_context = ssl._create_unverified_context

train_set = datasets.FashionMNIST(root='FashionMNIST_data/',
                                  train=True,
                                  download=True,
                                  transform=transforms.Compose([transforms.ToTensor()]))
# train=False selects the held-out test split.
test_set = datasets.FashionMNIST(root='FashionMNIST_data/',
                                 train=False,
                                 download=True,
                                 transform=transforms.Compose([transforms.ToTensor()]))

batch_size = 100
train_loader = DataLoader(train_set,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=True)
# The evaluation loader must read test_set (not train_set). It needs neither
# shuffling nor dropping samples: every test image should be scored exactly once.
test_loader = DataLoader(test_set,
                         batch_size=batch_size,
                         shuffle=False)
import torch.nn as nn
import torch.nn.functional as F
class ImageNN(nn.Module):
    """Three-layer fully connected classifier: 784 -> 256 -> 128 -> 10 logits."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=784, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return unnormalised class scores of shape (N, 10) for ``x``.

        ``x`` is viewed as (-1, 784), so any tensor whose element count is a
        multiple of 784 is accepted (e.g. a batch of 1x28x28 images).
        """
        flat = x.view(-1, 784)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)
import torch.optim as optim
# Criterion, network and a plain SGD optimizer over the network's parameters.
loss_func = nn.CrossEntropyLoss()
model = ImageNN()
optimizer = optim.SGD(params=model.parameters(), lr=0.01)
def train(model, train_loader, optimizer, loss_fn=F.cross_entropy):
    """Run one training epoch of ``model`` over ``train_loader``.

    Args:
        model: module mapping a batch of inputs to class logits.
        train_loader: iterable yielding ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: criterion ``loss_fn(logits, targets) -> scalar tensor``.
            The default ``F.cross_entropy`` computes the same value as
            ``nn.CrossEntropyLoss()`` with default settings.

    Returns:
        Mean of the per-batch losses as a float (0.0 if no batches were yielded).
    """
    # Put dropout / batch-norm style layers (if any) into training mode.
    model.train()
    total_loss = 0.0
    num_batches = 0
    for x_train, y_train in train_loader:
        optimizer.zero_grad()
        hypothesis = model(x_train)
        loss = loss_fn(hypothesis, y_train)
        loss.backward()
        optimizer.step()
        # .item() yields a detached float, so no autograd graph is kept alive.
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else 0.0
def evaluate(model, test_loader):
    """Compute the mean cross-entropy loss and accuracy of ``model`` on ``test_loader``.

    Args:
        model: module mapping a batch of inputs to class logits.
        test_loader: iterable yielding ``(inputs, targets)`` batches.

    Returns:
        ``(test_loss, test_accuracy)``: mean per-sample cross-entropy and the
        accuracy in percent. Both are computed over the samples the loader
        actually yields, and both are ``0.0`` if it yields none. The model's
        train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    test_loss = 0.0
    correct = 0
    seen = 0
    try:
        with torch.no_grad():
            for x_test, y_test in test_loader:
                hypothesis = model(x_test)
                # Sum per-sample losses (not per-batch means) so dividing by
                # the sample count gives the true per-sample mean, even when
                # the last batch is smaller.
                test_loss += F.cross_entropy(hypothesis, y_test, reduction='sum').item()
                pred = torch.argmax(hypothesis, dim=1)
                correct += pred.eq(y_test.view_as(pred)).sum().item()
                seen += y_test.size(0)
    finally:
        model.train(was_training)
    if seen == 0:
        return 0.0, 0.0
    return test_loss / seen, 100 * correct / seen
# 20 epochs: one optimisation pass over train_loader, then an evaluation
# pass over test_loader whose loss/accuracy is printed.
for epoch in range(20):
    train(model, train_loader, optimizer)
    test_loss, test_accuracy = evaluate(model, test_loader)
    print(
        f'epoch{epoch+1}, loss:{test_loss:.4f}, accuracy:{test_accuracy:2.2f}%'
    )
epoch1, loss:0.0102, accuracy:64.96%
epoch2, loss:0.0075, accuracy:71.43%
epoch3, loss:0.0065, accuracy:77.09%
epoch4, loss:0.0059, accuracy:79.60%
epoch5, loss:0.0054, accuracy:81.45%
epoch6, loss:0.0051, accuracy:82.44%
epoch7, loss:0.0049, accuracy:82.59%
epoch8, loss:0.0048, accuracy:83.26%
epoch9, loss:0.0048, accuracy:82.80%
epoch10, loss:0.0045, accuracy:84.14%
epoch11, loss:0.0044, accuracy:84.66%
epoch12, loss:0.0044, accuracy:84.70%
epoch13, loss:0.0043, accuracy:85.07%
epoch14, loss:0.0042, accuracy:85.38%
epoch15, loss:0.0042, accuracy:85.50%
epoch16, loss:0.0041, accuracy:85.78%
epoch17, loss:0.0041, accuracy:85.68%
epoch18, loss:0.0040, accuracy:85.97%
epoch19, loss:0.0040, accuracy:86.19%
epoch20, loss:0.0040, accuracy:86.12%