10/4 주요 코드이전 실습
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
iris = load_iris()
feature_names = iris.feature_names
n_feature = len(feature_names)
iris_data = iris.data
sepal_length = iris.data[:, 0]
sepal_width = iris.data[:, 1]
petal_length = iris.data[:, 2]
petal_width = iris.data[:, 3]
target = iris.target
np.random.seed(0)
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
setosa_sepal_length = sepal_length[target == 0]
versicolor_sepal_length = sepal_length[target == 1]
virginica_sepal_length = sepal_length[target == 2]
setosa_sepal_width = sepal_width[target == 0]
versicolor_sepal_width = sepal_width[target == 1]
virginica_sepal_width = sepal_width[target == 2]
setosa_petal_length = petal_length[target == 0]
versicolor_petal_length = petal_length[target == 1]
virginica_petal_length = petal_length[target == 2]
setosa_petal_width = petal_width[target == 0]
versicolor_petal_width = petal_width[target == 1]
virginica_petal_width = petal_width[target == 2]
axes[0, 0].violinplot([setosa_sepal_length, versicolor_sepal_length, virginica_sepal_length])
axes[0, 0].set_xticks([1, 2, 3])
axes[0, 0].set_xticklabels(['setosa', 'versicolor', 'virginica'])
axes[0, 0].set_title('Sepal Length (cm)')
axes[0, 1].violinplot([setosa_sepal_width, versicolor_sepal_width, virginica_sepal_width])
axes[0, 1].set_xticks([1, 2, 3])
axes[0, 1].set_xticklabels(['setosa', 'versicolor', 'virginica'])
axes[0, 1].set_title('Sepal Width (cm)')
axes[1, 0].violinplot([setosa_petal_length, versicolor_petal_length, virginica_petal_length])
axes[1, 0].set_xticks([1, 2, 3])
axes[1, 0].set_xticklabels(['setosa', 'versicolor', 'virginica'])
axes[1, 0].set_title('Petal Length (cm)')
axes[1, 1].violinplot([setosa_petal_width, versicolor_petal_width, virginica_petal_width])
axes[1, 1].set_xticks([1, 2, 3])
axes[1, 1].set_xticklabels(['setosa', 'versicolor', 'virginica'])
axes[1, 1].set_title('Petal Width (cm)')
plt.show()
for 문을 쓰지 못하고 무식하게 옮겨 적기만 했다. 이 부분이 참 아쉬운데, 강사 님이 enumerate 함수를 잘 사용하면 될거라고 말했지만, 사실 어떻게 하는지 감이 잡히지 않아 복습겸 적어보는 내용.
arange함수 알아보기
그 전에 arange 함수가 궁금해서 다시 찾아봤다. 모르는 게 너무 많아..
https://www.youtube.com/watch?v=lsuFr1L0hEk
위 강의를 참조했더니 아주 좋은 이해가 되었다.
arange = array range. 즉, 배열의 범위를 설정해주는 함수라는 걸 알 수 있다.
들어가는 숫자에 따라서 의미가 달라지는데 한 번 같이 알아보자.
np.arange에서 숫자가 하나만 들어가면 python의 range와 동일하게 생각하면 된다. ~까지
숫자 두 개가 나온다면? 시작하고 끝나는 지점을 설정해준다 생각하면 된다.
세 개가 나온다면 어떨까. 시작 , 끝, 단위를 설정해 준다고 생각하면 되는 것이다.
a의 출력값: 0,1,2,3,4
b의 출력값: 1,3,5,7, 9
c의 출력값: 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 ,1.7, 1.8, 1.9
그렇다면 활용은 어떻게 할까? 아래를 보자!
import matplotlib.pyplot as plt
x = np.arange(0, 10, 0.1)
y = np.sin(x)
plt.plot(x, y)
plt.show()
x축을 0부터 10까지 0.1간격으로 원소를 가지는 배열로 생성했고, y축은 x의 sin값으로 지정해주었다.
결과물은 어떻게 나올까?
위처럼 출력이 되는 것을 알 수 있다. 0.1간격이 보이지 않겠지만, 이렇게 보면 0.1로 간격이 되었음을 알 수 있다.
import matplotlib.pyplot as plt
x = np.arange(0, 10, 2)
y = np.sin(x)
plt.plot(x, y)
plt.show()
단위를 2로 하니 sin그래프가 아주 못생기게 나온다. 이렇게 numpy의 arange 함수를 이용해 편리하게 데이터 시각화를 진행할 수 있다.
enumerate
python 메서드 중 하나인 enumerate다.
인덱스의 값과 원소의 값을 두 가지다 확인해 번거로운 작업을 피할 수 있게 해주는 메서드.
이때 출력값은 tuple형태로 원소를 반환한다는 것을 유의하자.
여기서 for 문을 돌려보면 이해가 빠르다.
쉬운 형태를 보자면 아래와 같다.
>>> t = [1, 5, 7, 33, 39, 52]
>>> for p in enumerate(t):
... print(p)
...
(0, 1)
(1, 5)
(2, 7)
(3, 33)
(4, 39)
(5, 52)
for loop에서 사용한다면 for p in enumerate(t)라는 형식으로 인덱스와 밸류를 추출해낼 수 있다.
내가 문제에 도달했던 코드를 살펴보자.
import matplotlib.pyplot as plt
import numpy as np
PI = np.pi
#reshape: shape를 바꿔준다. (1000, ) 에서 (1,1000)으로 바꿔줌
#질문해야 할듯?
t = np.linspace(-4*PI, 4*PI, 1000).reshape(1, -1)
sin = np.sin(t)
cos = np.cos(t)
tan = np.tan(t)
#각각의 데이터들을 차례대로 접근해 수정하기 위해 행렬의 형태로 바꿔주는 것
#(3,1000)짜리 행렬이 만들어지게 됨
#vertical stack: 수직방향으로 쌓아준다. cf) hstack: horizonal
data = np.vstack((sin, cos, tan))
#결과적으로 matrix의 행렬은 (3,1000)형태가 될듯
title_list = [r'$sin(t)$', r'$cos(t)$', r'$tan(t)$']
x_ticks = np.arange(-4*PI, 4*PI+PI, PI)
x_ticklabels = [str(i) + r'$\pi$' for i in range(-4,5)]
fig, axes = plt.subplots(3, 1,
figsize=(7, 10),
sharex=True)
for ax_idx, ax in enumerate(axes.flat):
#flatten을 통해 vectorization을 해주는 것
ax.plot(t.flatten(), data[ax_idx])
ax.set_title(title_list[ax_idx],
fontsize=30)
ax.tick_params(labelsize=20)
ax.grid()
if ax_idx == 2:
ax.set_ylim([-3, 3])
fig.subplots_adjust(left=0.1, right=0.95,
bottom=0.05, top=0.95)
axes[-1].set_xticks(x_ticks)
axes[-1].set_xticklabels(x_ticklabels)
plt.show()