NameError: name 'pickle' is not defined

안녕하세요·2023년 8월 10일
0

파이썬 오류 기록

목록 보기
5/8
post-thumbnail

📍 문제상황

class Regression():  # Batched prediction/accuracy over the MNIST test split using a pickled network; quoted here as the failing example.
    import numpy as np  # NOTE(review): class-body import; `np` is bound in the class namespace, not in module globals
    import sys, os  # same: these names live in the class namespace only
    sys.path.append(os.pardir)  # executed once, when the class body runs; appends the parent directory to sys.path
    from MNIST_data import load_mnist  # same: bare `load_mnist` is not visible from inside the methods below
    import pickle  # same: this is the name the traceback reports as undefined in init_network
    
    def __init__(self):  # sets up no instance state
        pass
    
    def get_data(self):  # returns only the test split (x_test, t_test); the training split is discarded
        (x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)  # NOTE(review): bare `load_mnist` is looked up in globals, not in the class body
        return x_test, t_test
    
    def init_network(self, weight_file: str):  # unpickles and returns the object stored in weight_file
        with open(weight_file, 'rb') as f:
            network = pickle.load(f)  # raises NameError: method scopes do not search the enclosing class body
        return network
    
    def predict(self, x):  # three-layer forward pass: sigmoid, sigmoid, softmax
        W1, W2, W3 = network['W1'], network['W2'], network['W3']  # NOTE(review): `network` is neither a parameter nor a local of this method
        b1, b2, b3 = network['b1'], network['b2'], network['b3']

        a1 = np.dot(x, W1) + b1  # NOTE(review): bare `np` has the same lookup problem as `pickle`
        z1 = self.sigmoid(a1)  # NOTE(review): no `sigmoid` method is defined in this class as shown
        a2 = np.dot(z1, W2) + b2
        z2 = self.sigmoid(a2)
        a3 = np.dot(z2, W3) + b3
        y = self.softmax(a3)  # NOTE(review): no `softmax` method is defined in this class as shown
        
        return y
    
    def regression(self, weight_file, batch_size:int):  # loads the network, predicts in batches, prints accuracy; returns nothing
        network = self.init_network(weight_file)  # local variable only; predict() cannot see it
        predict = self.predict()  # NOTE(review): predict requires `x`, so this call raises TypeError once reached
        accuracy_cnt = 0
        x, t = self.get_data()
        
        for i in range(0, len(x), batch_size):
            x_batch = x[i:i+batch_size]
            y_batch = predict(x_batch)  # NOTE(review): `predict` holds the result of self.predict(), not the method itself
            p= np.argmax(y_batch, axis = 1) # index of the element with the highest probability
            accuracy_cnt += np.sum(p == t[i:i+batch_size])  # number of correct predictions in this batch

            print("Accuracy:" + str(float(accuracy_cnt) / len(x)))  # inside the loop: prints a running value after every batch

클래스 내부(클래스 본문)에서 pickle을 import했는데

⚠️ 오류

NameError                                 Traceback (most recent call last)
Cell In[2], line 2
      1 test = Regression()
----> 2 test.regression(r"경로", 100)

Cell In[1], line 34, in Regression.regression(self, weight_file, batch_size)
     33 def regression(self, weight_file, batch_size:int):
---> 34     network = self.init_network(weight_file)
     35     predict = self.predict()
     36     accuracy_cnt = 0

Cell In[1], line 17, in Regression.init_network(self, weight_file)
     15 def init_network(self, weight_file: str):
     16     with open(weight_file, 'rb') as f:
---> 17         network = pickle.load(f)
     18     return network

NameError: name 'pickle' is not defined

해당 에러가 발생

🤔 원인

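클래스 본문에서 실행한 import는 이름(pickle, np, load_mnist 등)을 모듈 전역이 아니라 클래스 네임스페이스에 바인딩한다. 즉 Regression.pickle 같은 클래스 속성이 될 뿐이다. 그런데 메서드 안에서 이름을 찾을 때 파이썬은 지역 → 바깥 함수 → 전역 → 내장 순서로만 탐색하며, 클래스 본문의 스코프는 메서드 안까지 확장되지 않는다. 그래서 init_network 안의 pickle은 어느 스코프에서도 찾을 수 없어 NameError가 발생한다. 같은 이유로 메서드 안에서 그냥 쓴 np, load_mnist도 동일한 오류를 일으킨다. (self.pickle.load(f)처럼 속성으로 접근하면 찾을 수는 있지만 권장되는 방식은 아니다.)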

✅ 해결

패키지 import는 클래스 외부, 즉 모듈 최상단에 두어야 한다. 그래야 이름이 모듈 전역에 바인딩되어 클래스의 모든 메서드에서 그대로 참조할 수 있다.

# 예시
import numpy as np
import sys, os
sys.path.append(os.pardir)
from MNIST_data import load_mnist
import pickle

class Regression():
    ...  # (중략)
profile
반갑습니다

0개의 댓글