[Linear Regression] 단순 선형회귀 with python

Surf in Data·2022년 4월 3일
1

machine learning

목록 보기
1/7
post-thumbnail
post-custom-banner

회귀분석이란?

  • 회귀 분석은 지도학습(supervised learning)중 하나인데 지도학습이란 y=f(x)에 대하여 입력변수(x)와 출력변수(y)의 관계에 대해 모델링을 하는것이다.
  • 지도학습은 회귀(regression) 와 분류(classification)으로 나뉘게 된다.
    • 회귀: 입력변수 x에 대해서 연속형 출력 변수 y를 예측
    • 분류: 입력변수 x에 대해서 이산형 출력 변수 y(class)를 예측
  • 회귀분석에는 선형회귀와 비선형회귀가 있다.

단순 선형 회귀분석

실제 데이터 x = [1, 2, 3], y=[3, 5, 7] 이 있을때 만약 새로운 관측치 x가 4일때 y값을 어떻게 될까??
데이터로부터 우리는 y=2x+1 이라는 식을 유도할 수 있고 유추한 식에 x=4를 대입해 y값을 9라고 예측할 수 있다.
바로 이러한 x(하나의 독립변수)로 y(하나의 종속변수)를 예측하는 함수를 찾는 과정이다.

단순회귀 분석의 결과


단순 선형 회귀분석 실습

import os
import pandas as pd 
import numpy as np
import statsmodels.api as sm

boston집값 데이터 불러오기

boston = pd.read_csv("./Boston_house.csv")
boston.head()
AGE B RM CRIM DIS INDUS LSTAT NOX PTRATIO RAD ZN TAX CHAS Target
0 65.2 396.90 6.575 0.00632 4.0900 2.31 4.98 0.538 15.3 1 18.0 296 0 24.0
1 78.9 396.90 6.421 0.02731 4.9671 7.07 9.14 0.469 17.8 2 0.0 242 0 21.6
2 61.1 392.83 7.185 0.02729 4.9671 7.07 4.03 0.469 17.8 2 0.0 242 0 34.7
3 45.8 394.63 6.998 0.03237 6.0622 2.18 2.94 0.458 18.7 3 0.0 222 0 33.4
4 54.2 396.90 7.147 0.06905 6.0622 2.18 5.33 0.458 18.7 3 0.0 222 0 36.2

단순 선형회귀 분석이므로 하나의 x값을 정하기

x = boston["RM"]  #x값
target = boston['Target']  #y값
train_data = sm.add_constant(x, has_constant='add')
train_data
const RM
0 1.0 6.575
1 1.0 6.421
2 1.0 7.185
3 1.0 6.998
4 1.0 7.147
... ... ...
501 1.0 6.593
502 1.0 6.120
503 1.0 6.976
504 1.0 6.794
505 1.0 6.030

506 rows × 2 columns

회귀계수 구하기

model1 = sm.OLS(target, train_data)
fitted_model1=model1.fit()
fitted_model1.summary()
OLS Regression Results
Dep. Variable: Target R-squared: 0.484
Model: OLS Adj. R-squared: 0.483
Method: Least Squares F-statistic: 471.8
Date: Sat, 19 Mar 2022 Prob (F-statistic): 2.49e-74
Time: 21:27:43 Log-Likelihood: -1673.1
No. Observations: 506 AIC: 3350.
Df Residuals: 504 BIC: 3359.
Df Model: 1
Covariance Type: nonrobust
coef std err t P>|t| [0.025 0.975]
const -34.6706 2.650 -13.084 0.000 -39.877 -29.465
RM 9.1021 0.419 21.722 0.000 8.279 9.925
Omnibus: 102.585 Durbin-Watson: 0.684
Prob(Omnibus): 0.000 Jarque-Bera (JB): 612.449
Skew: 0.726 Prob(JB): 1.02e-133
Kurtosis: 8.190 Cond. No. 58.4


Notes:
[1] Standard Errors assume that the covariance matrix of the errors is correctly specified.

p-value는 0 이므로 RM은 유의미하다고 할 수 있고
결정계수 R-squared는 RM으로 설명할수 있는 target값의 변동이므로 0.484가 나온것을 확인할 수 있다.

fitted_model1.params  #회귀계수
const   -34.670621
RM        9.102109
dtype: float64

x(RM), y(Target)으로 추정한 회귀식 y = -34.670621x +9.102109

구한 회귀식으로 y의 예측값 구하기

#첫번째 방법
(np.dot(train_data,fitted_model1.params))
array([25.17574577, 23.77402099, 30.72803225, 29.02593787, 30.38215211,
       23.85593997, 20.05125842, 21.50759586, 16.5833549 , 19.97844155,
       23.3735282 , 20.02395209, 18.93169901, 19.47782555, 20.81583557,
       18.43108302, 19.35039603, 19.85101202, 14.99048582, 17.45715736,
       16.02812625, 19.6234593 , 21.23453259, 18.23993873, 19.25027283,
       16.29208741, 18.23993873, 20.36983223, 24.44757706, 26.07685456,
       17.32972783, 20.59738496, 19.48692766, 17.22050253, 20.81583557,
       19.33219181, 18.49479778, 18.57671676, 19.63256141, 25.35778795,
       29.26259271, 26.95065703, 21.48028953, 21.86257811, 20.57007863,
       17.04756245, 17.99418179, 20.21509638, 14.47166561, 16.31939374,
       19.60525508, 20.98877564, 24.5932108 , 19.92382889, 18.9225969 ,
       31.31056723, 23.42814085, 27.36935404, 21.26183891, 19.27757916,
       17.58458688, 19.63256141, 24.09259481, 26.87784015, 29.99076143,
       22.58164472, 18.0032839 , 18.83157581, 16.24657686, 18.89529058,
       23.73761256, 19.58705086, 20.53367019, 22.17204981, 22.42690886,
       22.54523628, 22.48152152, 21.21632837, 22.05372239, 18.79516738,
       26.55926634, 25.57623857, 22.69087002, 21.46208531, 23.4827535 ,
       25.67636177, 20.07856475, 21.0433883 , 29.10785685, 29.7632087 ,
       23.73761256, 23.62838725, 23.96516528, 21.86257811, 22.20845825,
       25.63085122, 21.42567687, 38.77429659, 36.50787146, 32.83061943,
       26.55926634, 27.05078022, 23.62838725, 21.18902204, 21.46208531,
       18.58581887, 18.44928724, 21.09800095, 24.25643277, 22.02641607,
       21.71694436, 26.45004103, 19.15014963, 20.77942714, 22.25396879,
       19.28668126, 21.54400429, 20.1331774 , 18.77696316, 17.49356579,
       18.75875894, 19.97844155, 19.58705086, 18.63132942, 18.84067792,
       19.81460358, 16.41951693, 17.14768565, 23.86504208, 16.63796755,
       24.11079902, 22.90932064, 23.32801765, 18.32185771, 17.73022063,
       22.99123962, 19.41411079, 24.07439059, 18.64043153, 21.31645157,
       21.52580007, 11.0128642 , 14.50807405, 15.09971113,  9.95701956,
       21.12530728, 16.55604857, 10.16636806, 12.5329164 , 16.27388319,
       21.05249041, 14.51717616, 10.94914944, 17.2933194 , 21.11620517,
       21.32555368, 13.31569777, 28.52532188, 20.5427723 , 24.58410869,
       22.21756036, 33.49507338, 36.34403349, 41.55954194, 18.6131252 ,
       20.86134612, 37.50000134, 18.82247371, 22.84560588, 23.60108092,
       18.80426949, 18.84978003, 16.04633047, 23.72851045, 18.65863574,
       24.91178461, 20.12407529, 22.80919744, 27.76984683, 28.86209991,
       36.00725546, 21.2527368 , 30.45496898, 25.06652047, 16.33759795,
       21.33465578, 36.60799466, 27.05988233, 25.0028057 , 30.72803225,
       28.59813875, 26.66849165, 30.66431749, 27.2237203 , 25.43970694,
       37.00848745, 31.65644737, 30.01806775, 31.53811995, 28.81658937,
       30.2729268 , 21.41657477, 34.59642857, 36.80824105, 38.45572278,
       18.94990323, 22.90932064, 17.96687546, 20.52456809, 13.97104962,
       19.57794875, 14.51717616, 18.18532608, 23.35532398, 14.58999303,
       21.59861695, 18.9225969 , 25.78558708, 19.49602977, 23.33711976,
       28.59813875, 21.43477898, 27.94278691, 25.56713646, 40.56741206,
       44.74528008, 38.51033543, 30.52778586, 35.28818885, 24.96639727,
       19.76909304, 32.79421099, 41.2136618 , 40.39447199, 26.55016423,
       20.72481448, 25.68546388, 32.30269711, 24.32014753, 25.45791115,
       28.10662487, 20.80673346, 23.20058813, 23.51916194, 16.23747476,
       16.34670006, 20.92506088, 21.99910974, 23.8832463 , 26.47734736,
       24.37476018, 23.92875684, 28.65275141, 40.5036973 , 20.92506088,
       18.8133716 , 33.17649957, 44.5541358 , 32.07514438, 27.60600887,
       30.89187022, 33.77723876, 41.76889045, 32.02053173, 30.91917654,
       15.93710516, 29.17157162, 40.84957744, 33.32213331, 19.21386439,
       18.63132942, 22.12653927, 24.83896774, 35.3336994 , 26.84143172,
       27.71523418, 31.47440519, 27.46037513, 24.32924964, 27.3329456 ,
       36.50787146, 28.7528746 , 34.91500238, 37.44538868, 29.84512768,
       24.06528848, 22.03551818, 21.84437389, 22.80919744, 25.08472469,
       27.77894894, 30.39125422, 25.67636177, 21.09800095, 20.02395209,
       26.113263  , 24.93909094, 18.03059022, 23.08226071, 29.41732856,
       27.86997003, 25.31227741, 24.44757706, 28.88030413, 31.19223981,
       25.54893224, 32.86702786, 27.66972364, 25.72187231, 19.68717406,
       10.59416719, 21.05249041, 20.15138162, 22.3631941 , 25.1029289 ,
       17.25691096, 19.15925174, 17.95777335, 23.41903874, 20.97057143,
       23.81953154, 23.36442609, 20.31521958, 17.28421729, 23.71940834,
       23.86504208, 22.78189111, 20.69750816, 18.74055473, 22.9730354 ,
       21.2527368 , 17.26601307, 20.22419849, 22.81829955, 22.76368689,
       20.27881114, 18.74965683, 18.98631167, 20.47905754, 19.80550148,
       19.65076562, 31.23775036, 24.85717196, 26.27710096, 27.89727636,
       20.06946264, 19.01361799, 24.63872134, 25.72187231, 28.48891344,
       24.40206651, 25.21215421, 18.88618847, 26.56836845, 16.87462238,
       19.35949814, 21.87168021, 23.53736616, 21.09800095, 20.96146932,
       23.56467249, 22.22666246, 14.13488758, 18.14891764, 45.24589608,
       -2.25801069, 10.5031461 ,  0.49082622, 10.56686086, 26.15877354,
       29.18977584, 21.90808865, 18.80426949,  9.98432589,  2.99390619,
       31.8931022 , 25.84930184, 27.16910764, 23.40083452, 21.97180341,
       28.7528746 , 24.90268251, 15.71865454, 15.5730208 ,  5.08739125,
       13.36120832,  7.6723902 , 10.83992413,  9.74767105, 14.38974663,
       17.32972783, 20.40624067, 11.16760005, 21.69874014, 18.9134948 ,
       24.22912644, 23.62838725, 17.63919954, 14.9631795 , 18.59492098,
       19.82370569, 23.06405649, 23.61928514, 14.01656016, 15.673144  ,
       17.05666456,  2.99390619, 16.37400639, 16.45592537, 27.69702996,
       17.73022063, 25.92211871,  7.45393959, 12.25075102,  6.46180971,
       23.89234841, 27.05988233, 13.60696526, 19.55064242, 27.44217091,
       23.6829999 , 19.99664576, 16.73809075, 20.87955034, 15.9826157 ,
       18.99541378, 18.45838935, 21.78065912, 21.69874014, 23.40083452,
       23.10956704, 27.52408989, 23.81042943, 23.91055263, 21.83527178,
       25.66725966, 24.13810535, 21.32555368, 19.35039603, 16.54694646,
       18.28544928, 23.63748936, 21.93539498, 24.35655597, 18.6131252 ,
       24.11990113, 23.04585227, 22.22666246, 21.62592327, 23.73761256,
       26.75951274, 25.90391449, 22.64535948, 32.62127092, 26.56836845,
       24.72064033, 19.7235825 , 19.35949814, 22.68176791, 20.67930394,
       26.32261151, 23.36442609, 22.82740166, 24.61141502, 21.84437389,
       17.74842485, 19.50513188, 19.96933944, 19.26847705, 17.32972783,
       21.46208531, 22.02641607, 23.91965474, 28.86209991, 14.72652466,
       21.41657477, 24.34745386, 13.60696526, 21.62592327, 22.02641607,
       22.14474348, 26.76861485, 29.59937074, 17.77573117, 18.76786105,
       22.78189111, 20.97967353, 19.07733276, 14.97228161, 14.60819725,
       11.68642026, 19.78729726, 19.78729726, 17.27511518, 19.26847705,
       16.93833715, 14.38974663, 18.06699866, 20.11497318, 16.01902414,
       20.18779005, 25.33958374, 21.03428619, 28.82569148, 27.16910764,
       20.21509638])
#2번째 방법
pred=fitted_model1.predict(train_data)
print(pred)
0      25.175746
1      23.774021
2      30.728032
3      29.025938
4      30.382152
         ...    
501    25.339584
502    21.034286
503    28.825691
504    27.169108
505    20.215096
Length: 506, dtype: float64
pred - (np.dot(train_data,fitted_model1.params)) #확인
0      0.0
1      0.0
2      0.0
3      0.0
4      0.0
      ... 
501    0.0
502    0.0
503    0.0
504    0.0
505    0.0
Length: 506, dtype: float64

잔차의 합

sum(fitted_model1.resid)
-4.300559908188006e-12

추정한 선형 모델과 실제값 시각화하기

import matplotlib.pyplot as plt
plt.scatter(x,target,label="y")
plt.plot(x,pred,label="y_hat")
plt.legend()
plt.show()

{: width="60%" height="60%"}

profile
study blog
post-custom-banner

0개의 댓글