[머신러닝] 선형회귀 (Linear Regression)

조세은·2022년 10월 28일
0

머신러닝

목록 보기
1/2

넘파이로 선형회귀 구현하기

import numpy as np
import matplotlib.pyplot as plt

"""
선형 회귀 Y = a * X + b

weight 와 bias 를 조절해가며 학습 데이터와 자신의 예측 데이터 간의 오차를 줄이는 작업
-> "학습" , 경사 하강법을 사용하여 진행함.
"""

"""
학습 데이터
문제 : x 

정답 : y
"""
x = np.array([[8.70153760], [3.90825773], [1.89362433], [3.28730045],
[7.39333004], [2.98984649], [2.25757240], [9.84450732], [9.94589513], [5.48321616]])

y = np.array([[5.64413093], [3.75876583], [3.87233310], [4.40990425],
[6.43845020], [4.02827829], [2.26105955], [7.15768995], [6.29097441], [5.19692852]])

"""
경사하강법 과정
1. a = 0, b = 0 
-> 가중치를 0혹은 무작위로 초기화

2. Y - (a * X + b) 
-> 예측 값과 실제 값 사이의 차이 ( error ) 를 도출

3. 해당 에러를 경사하강법을 구하는 수식에 대입하여 각 가중치의 변화값(delta)를 보내줌.
-> loss function 을 미분시켜 그래프의 최소 값을 뜻하는 극소값을 찾는 과정
"""

# 학습률 결정 learning_rate
# 얼만큼 반복 n_iter

class LinearRegression:

    def __init__(self, x,y, learning_rate = 0.01, n_iter = 300):
        self.x = x
        self.y = y
        self.lr = learning_rate
        self.n_iter = n_iter

    #현재 가중치로 테스트 데이터 예측
    def prediction(self):
        self.equation = self.x*self.a + self.b
        return self.equation

    #오차 값을 통해 가중치 갱신
    def update_ab(self):
        #a를 업데이트하는 규칙을 만듬.
        self.delta_a = -(self.lr*(2/len(self.error))*(np.dot(self.x.T, self.error)))

        #b를 업데이트하는 규칙을 만듬
        self.delta_b = -(self.lr*(2/len(self.error))*np.sum(self.error))

        return self.delta_a, self.delta_b

    #경사하강법을 통한 가장 적합한 가중치 도출
    def gradient_descent(self):
        self.a = np.zeros((1,1))
        self.b = np.zeros((1,1))

        for i in range(self.n_iter):
            self.error = self.y - self.prediction()
            self.a_delta, self.b_delta = self.update_ab()
            self.a -= self.a_delta
            self.b -= self.b_delta


        return self.a, self.b

    #그래프 시각화 함수
    def plotting_graph(self):
        self.a ,self.b = self.gradient_descent()
        self.y_pred = self.a[0,0]*self.x*+self.b
        plt.scatter(self.x, self.y)
        plt.plot(self.x, self.y_pred)
        plt.savefig("new.png")

seny=LinearRegression(x,y)
seny.plotting_graph()
profile
학부생 3학년

0개의 댓글