최적화 알고리즘을 이용한 회귀식 추정

누디·2022년 10월 23일
0

최적화 알고리즘을 이용한 회귀식 추정

1차 함수로 선형 회귀 이해하기

y=ax+by = ax+b

위 1차 함수의 기울기는 a이고 절편은 b

선형 회귀는 기울기와 절편을 찾아주는 것

기울기와 절편을 조정해가면서 각 점들을 나타내는 함수를 찾는 것이 선형 회귀

문제 해결을 위해 데이터 준비 및 시각화

R에서 기본으로 제공하는 mtcars.csv를 이용

1974년에 미국의 모터 트렌드 잡지에 실린 1973~1974년 자동차 모델의 연료 소비, 10가지 디자인 요소, 성능을 비교한 데이터이다.

여기서 무게(x)에 따른 연비(y)를 나타냄

💻 입력코드

from pandas.io.parsers import read_csv
import matplotlib.pyplot as plt
import pandas as pd

df = read_csv('mtcars.csv')
print(df.shape)
print(df.head(5))
print(df.describe())

# 자동차의 무게(wt)와 연비(mpg)의 산점도 그래프
plt.scatter(df.wt, df.mpg)
plt.xlabel('weight')
plt.ylabel('miles per gallon')
plt.grid()
plt.savefig('wt-mpg.png')
plt.show()

💻 출력화면

스크린샷 2022-06-17 오후 11 24 40

예측값 찾기

예측값이란, 우리가 입력과 출력 데이터(x,y)를 통해 규칙(a,b)을 발견하면 모델을 만들었다고 한다.

그 모델에 대해 새로운 입력값을 넣으면 어떤 출력이 나오는데, 이 값이 모델을 통해 예측한 값

ŷ 과 w라는 문자 등장

y와 ŷ은 어떤 결과라는 점은 동일하지만 예측한 값이라는 점만 다르다.

w는 가중치를 의미하는 계수

이제 새로운 식으로 표현

y^=wx+bŷ=wx+b

💡 훈련 데이터에 잘 맞는 w와 b를 찾는 법 💡

1. 무작위로 w와 b를 정한다.
2. x에서 샘플 하나를 선택하여 ŷ을 계산
3. ŷ과 선택한 샘플의 진짜 y를 비교
4. ŷ와 y가 더 가까워지도록 w,b를 조정하기
5. 모든 샘플을 처리할 때까지 다시 2~4 반복

알고리즘 적용

  1. Genetic Algoritm

    알고리즘에 대한 설명 및 동작 원리는 Genetic-Algorithm.md

  2. Simulated Algorithm

    알고리즘에 대한 설명 및 동작 원리는 Simulted-Algorithm.md

    • 독립 변수 x (무게, weight)
    • 종속변수 y (연비, mpg)
    • 임의의 모수(기울기, 절편)를 무작위로 대입
    • 평균제곱근오차(RMSE)가 최소가 되는 해를 구한다.

결과 고찰


임의의 모수를 1 간격으로 설정한 결과

slopeList = np.arange(-100, 100, 1)
interceptList = np.arange(-100, 100, 1)

스크린샷 2022-06-18 오전 12.29.48.png

RMSE가 최소가 되는

slope : -5.000

intercept : 36.000

세부 a,b를 위해서 값을 상세히 조정한 결과

slopeList = np.arange(-10, 10, 0.1)
interceptList = np.arange(30, 40, 0.1)

스크린샷 2022-06-18 오전 12.30.35.png

slope : -5.30

intercept : 37.10

더 정확한 결과가 나왔다

모수를 더 좁게 설정할수록 더 정확한 결과가 나올 것이라 생각한다.

도출된 기울기와 절편을 다시 산포도 위에서 확인하면

스크린샷 2022-06-18 오전 1.01.16.png

위에 도출된 결과 slope와 intercept의 상수값을 이용해서 직접 산점도 위에 직선을 그리고 싶었으나 실패하였다..

따라서 다른 방법으로 시도하였다.

import seaborn as sns; sns.set()
#을 이용하여
ax = sns.lmplot(x="wt", y="mpg", data = df)
#코드를 추가하면

위 사진의 결과를 확인할 수 있다.

sns.lmplot 을 이용하면 산점도와 선형회귀선 그리고 95% 신뢰구간(Confidence interval)이 기본적으로 표현된다.

0개의 댓글