Artificial Intelligence #12 Linear Regression 실습

김서영·2024년 12월 15일

인공지능

목록 보기
11/13
post-thumbnail

1. Linear Regression Base code

data와 target 분리

from sklearn.datasets import fetch_california_housing

# Load the California housing dataset, then pull out the feature matrix
# (`data`) and the regression labels (`target`) from the returned bunch.
dataset = fetch_california_housing()
data, target = dataset.data, dataset.target

train, test 분리

from sklearn.model_selection import train_test_split

# Hold out 20% of the samples for testing; the fixed random_state makes the
# shuffle (and therefore the split) reproducible between runs.
splits = train_test_split(
    data,
    target,
    test_size=0.2,
    random_state=42,
)
train_data, test_data, train_target, test_target = splits

# Sanity-check the resulting partition sizes.
print(f'train data shape : {train_data.shape}')
print(f'test data shape : {test_data.shape}')

모델 학습

from sklearn.linear_model import LinearRegression

# Fit ordinary least squares on the training split. `fit` returns the
# estimator itself, so construction and training can be chained.
model = LinearRegression().fit(train_data, train_target)

# For regressors, `score` reports the coefficient of determination (R^2);
# it is evaluated here on the held-out test split.
print(model.score(test_data, test_target))

성능 평가

from sklearn.metrics import mean_squared_error, r2_score

# Predict on the held-out split once, then print the mean squared error
# followed by R^2, one value per line.
test_pred = model.predict(test_data)
for metric in (mean_squared_error, r2_score):
    print(metric(test_target, test_pred))

2. Regularization 실습

Library 불러오기

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge, Lasso, LinearRegression
from sklearn.metrics import mean_squared_error

데이터 생성하기

# Synthetic regression problem with known ground truth: 100 samples of 5
# features drawn uniformly from [0, 1), and a target that is a fixed linear
# combination of them plus Gaussian noise scaled by 0.5.
np.random.seed(42)  # seed the global NumPy RNG so the data are reproducible

# Ground-truth weights; the two zeros mean features 3 and 4 carry no signal.
true_coefficients = np.array([1.5, -2.0, 0.0, 0.0, 3.0])

# Draw order matters for reproducibility: features first, then the noise.
data = np.random.rand(100, 5)
noise = np.random.randn(100) * 0.5
target = data @ true_coefficients + noise

# Same 80/20 reproducible split as in the first part of the tutorial.
train_data, test_data, train_target, test_target = train_test_split(
    data,
    target,
    test_size=0.2,
    random_state=42,
)

모델 정의하기

# Model definitions, keyed by the display name used in the plot legend and
# in the MSE report. Insertion order is preserved and determines bar order.
models = dict(
    [
        ("Linear Regression", LinearRegression()),
        ("Ridge", Ridge(alpha=1.0)),
        ("Lasso", Lasso(alpha=0.1)),
    ]
)


# Result stores, keyed by model name.
coefficients = dict()
mse = dict()

학습 및 테스트

# Fit every candidate on the training split, then record its learned
# coefficient vector and its mean squared error on the held-out split.
for name, model in models.items():
    # `fit` returns the estimator, so training and prediction can be chained.
    test_pred = model.fit(train_data, train_target).predict(test_data)
    coefficients[name] = model.coef_
    mse[name] = mean_squared_error(test_target, test_pred)

시각화

# Grouped bar chart of the fitted coefficients: one group per feature and,
# inside each group, one bar per entry in `coefficients`.
plt.figure(figsize=(10, 6))
bar_width = 0.2
x_indices = np.arange(len(true_coefficients))

for i, (name, coef) in enumerate(coefficients.items()):
    # Shift the i-th model's bars right by i bar widths so they sit side by side.
    plt.bar(x_indices + i * bar_width, coef, bar_width, label=name)

# Dashed zero line makes it easy to see which coefficients were shrunk to 0.
plt.axhline(0, color="black", linestyle="--", linewidth=0.8)

# Centre each tick label under its group. The offset is computed from the
# number of plotted models rather than hard-coded to one bar width, which was
# only correct for exactly three models; for three models it is still 0.2.
group_center = (len(coefficients) - 1) * bar_width / 2
plt.xticks(
    x_indices + group_center,
    [f"Feature {i+1}" for i in range(len(true_coefficients))],
)
plt.xlabel("Features")
plt.ylabel("Coefficient Value")
plt.title("Comparison of Coefficients")
plt.legend()
plt.show()

# Report each model's test-set mean squared error, rounded to four decimals.
for name, error in mse.items():
    print(f"{name}: MSE = {error:.4f}")

profile
안녕하세요 :)

0개의 댓글