04/03 인공지능

윤수환·2024년 4월 3일

인공지능

목록 보기
1/10

평균 제곱 오차(MSE)

선형 회귀 - 임의의 직선을 그어 이에 대한 평균 제곱 오차를 구하고, 이 값을 가장 작게 만들어 주는 a값과 b값을 찾아가는 작업

오차 = 실제 값 - 예측 값
(오차**2 / x원소의 총 개수)

평균 제곱 오차 구하기

import numpy as np

fake_a = 3;
fake_b = 76;

x = np.array([2, 4, 6, 8]);
y = np.array([81, 93, 91, 97]);

def predict(x):
    return fake_a * x + fake_b;

predict_result = [];

for i in range(len(x)):
    predict_result.append(predict(x[i]))
    print("공부시간=%.f, 실제 점수=%.f, 예측점수=%.f" % (x[i], y[i], predict(x[i])));

n = len(x)
def mse(y, y_pred):
    return (1/n) * sum((y - y_pred)**2);
print("평균 제곱 오차: " + str(mse(y, predict_result)))

선형 회귀 모델

오차 줄이기 - 경사 하강법
적절한 학습률을 설정해 미분값이 0인 지점을 찾는 것

학습률을 너무 크게 잡으면 수렴하지 않고 발산

값을 미분할 때 궁금한 것은 a,b(기울기,y절편)
따라서 a,b를 중심으로 편미분해야함

import numpy as np
import matplotlib.pyplot as plt

x = np.array([2, 4, 6, 8]);
y = np.array([81, 93, 91, 97]);

plt.scatter(x, y);
plt.show();

a = 0;
b = 0;
lr = 0.03;
epochs = 2001;

n = len(x);

for i in range(epochs):
    y_pred = a * x + b
    error = y - y_pred
    a_diff = (2/n) * sum(-x * (error))
    b_diff = (2/n) * sum(-(error))
    a = a - lr * a_diff
    b = b - lr * b_diff

    if 1 % 100 == 0:
        print("epoch = %.f, 기울기 = %.04f, 절편 = %.04f" % (i, a, b));

y_pred = a * x + b;

plt.scatter(x, y);
plt.plot(x, y_pred,'r');
plt.show();

다중 선형 회귀

import numpy as np
import matplotlib.pyplot as plt

x1 = np.array([2, 4, 6, 8]);
x2 = np.array([0, 4, 2, 3]);
y = np.array([81, 93, 91, 97]);

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter3D(x1, x2, y);
plt.show()

a1 = 0;
a2 = 0;
b = 0;
lr = 0.01;
epochs = 2001;
n = len(x1);

for i in range(epochs):
    y_pred = a1 * x1 + a2 * x2 + b
    error = y - y_pred
    a1_diff = (2/n) * sum(-x1 * (error))
    a2_diff = (2/n) * sum(-x2 * (error))
    b_diff = (2/n) * sum(-(error))

    a1 = a1 - lr * a1_diff
    a2 = a2 - lr * a2_diff
    b = b - lr * b_diff

    if i % 100 == 0:
        print("epoch=%.f, 기울기 1 = %0.4f, 기울기 2 = %0.4f, 절편 = %0.4f" % (i, a1, a2, b));
print("실제 점수: ", y)
print("예측 점수: ", y_pred)

0개의 댓글