[numpy/matplotlib] arange / enumerate 연습

Pygmalion Dali·2023년 10월 5일
0

matplotlib

목록 보기
3/3
post-thumbnail

이전 실습

10/4 주요 코드
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

iris = load_iris()

feature_names = iris.feature_names
n_feature = len(feature_names)

iris_data = iris.data

sepal_length = iris.data[:, 0]
sepal_width = iris.data[:, 1]
petal_length = iris.data[:, 2]
petal_width = iris.data[:, 3]

target = iris.target

np.random.seed(0)
fig, axes = plt.subplots(2, 2, figsize=(12, 12))

setosa_sepal_length = sepal_length[target == 0]
versicolor_sepal_length = sepal_length[target == 1]
virginica_sepal_length = sepal_length[target == 2]

setosa_sepal_width = sepal_width[target == 0]
versicolor_sepal_width = sepal_width[target == 1]
virginica_sepal_width = sepal_width[target == 2]

setosa_petal_length = petal_length[target == 0]
versicolor_petal_length = petal_length[target == 1]
virginica_petal_length = petal_length[target == 2]

setosa_petal_width = petal_width[target == 0]
versicolor_petal_width = petal_width[target == 1]
virginica_petal_width = petal_width[target == 2]

axes[0, 0].violinplot([setosa_sepal_length, versicolor_sepal_length, virginica_sepal_length])
axes[0, 0].set_xticks([1, 2, 3])
axes[0, 0].set_xticklabels(['setosa', 'versicolor', 'virginica'])
axes[0, 0].set_title('Sepal Length (cm)')

axes[0, 1].violinplot([setosa_sepal_width, versicolor_sepal_width, virginica_sepal_width])
axes[0, 1].set_xticks([1, 2, 3])
axes[0, 1].set_xticklabels(['setosa', 'versicolor', 'virginica'])
axes[0, 1].set_title('Sepal Width (cm)')

axes[1, 0].violinplot([setosa_petal_length, versicolor_petal_length, virginica_petal_length])
axes[1, 0].set_xticks([1, 2, 3])
axes[1, 0].set_xticklabels(['setosa', 'versicolor', 'virginica'])
axes[1, 0].set_title('Petal Length (cm)')

axes[1, 1].violinplot([setosa_petal_width, versicolor_petal_width, virginica_petal_width])
axes[1, 1].set_xticks([1, 2, 3])
axes[1, 1].set_xticklabels(['setosa', 'versicolor', 'virginica'])
axes[1, 1].set_title('Petal Width (cm)')

plt.show()

for 문을 쓰지 못하고 무식하게 옮겨 적기만 했다. 이 부분이 참 아쉬운데, 강사 님이 enumerate 함수를 잘 사용하면 될거라고 말했지만, 사실 어떻게 하는지 감이 잡히지 않아 복습겸 적어보는 내용.

arange함수 알아보기

그 전에 arange 함수가 궁금해서 다시 찾아봤다. 모르는 게 너무 많아..

https://www.youtube.com/watch?v=lsuFr1L0hEk

위 강의를 참조했더니 아주 좋은 이해가 되었다.

arange = array range. 즉, 배열의 범위를 설정해주는 함수라는 걸 알 수 있다.

들어가는 숫자에 따라서 의미가 달라지는데 한 번 같이 알아보자.

np.arange에서 숫자가 하나만 들어가면 python의 range와 동일하게 생각하면 된다. ~까지

숫자 두 개가 나온다면? 시작하고 끝나는 지점을 설정해준다 생각하면 된다.

세 개가 나온다면 어떨까. 시작 , 끝, 단위를 설정해 준다고 생각하면 되는 것이다.

a의 출력값: 0,1,2,3,4
b의 출력값: 1,3,5,7, 9
c의 출력값: 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 ,1.7, 1.8, 1.9

그렇다면 활용은 어떻게 할까? 아래를 보자!

import matplotlib.pyplot as plt

x = np.arange(0, 10, 0.1)
y = np.sin(x)
plt.plot(x, y)
plt.show()

x축을 0부터 10까지 0.1간격으로 원소를 가지는 배열로 생성했고, y축은 x의 sin값으로 지정해주었다.

결과물은 어떻게 나올까?

위처럼 출력이 되는 것을 알 수 있다. 0.1간격이 보이지 않겠지만, 이렇게 보면 0.1로 간격이 되었음을 알 수 있다.

import matplotlib.pyplot as plt

x = np.arange(0, 10, 2)
y = np.sin(x)
plt.plot(x, y)
plt.show()

단위를 2로 하니 sin그래프가 아주 못생기게 나온다. 이렇게 numpy의 arange 함수를 이용해 편리하게 데이터 시각화를 진행할 수 있다.

enumerate

python 메서드 중 하나인 enumerate다.

인덱스의 값과 원소의 값을 두 가지다 확인해 번거로운 작업을 피할 수 있게 해주는 메서드.

이때 출력값은 tuple형태로 원소를 반환한다는 것을 유의하자.

여기서 for 문을 돌려보면 이해가 빠르다.

쉬운 형태를 보자면 아래와 같다.

>>> t = [1, 5, 7, 33, 39, 52]
>>> for p in enumerate(t):
...     print(p)
... 
(0, 1)
(1, 5)
(2, 7)
(3, 33)
(4, 39)
(5, 52)

for loop에서 사용한다면 for p in enumerate(t)라는 형식으로 인덱스와 밸류를 추출해낼 수 있다.

내가 문제에 도달했던 코드를 살펴보자.

import matplotlib.pyplot as plt
import numpy as np

PI = np.pi
#reshape: shape를 바꿔준다. (1000, ) 에서 (1,1000)으로 바꿔줌
#질문해야 할듯?
t = np.linspace(-4*PI, 4*PI, 1000).reshape(1, -1)
sin = np.sin(t)
cos = np.cos(t)
tan = np.tan(t)

#각각의 데이터들을 차례대로 접근해 수정하기 위해 행렬의 형태로 바꿔주는 것
#(3,1000)짜리 행렬이 만들어지게 됨
#vertical stack: 수직방향으로 쌓아준다. cf) hstack: horizonal
data = np.vstack((sin, cos, tan))

#결과적으로 matrix의 행렬은 (3,1000)형태가 될듯
title_list = [r'$sin(t)$', r'$cos(t)$', r'$tan(t)$']
x_ticks = np.arange(-4*PI, 4*PI+PI, PI)
x_ticklabels = [str(i) + r'$\pi$' for i in range(-4,5)]
fig, axes = plt.subplots(3, 1,
                         figsize=(7, 10),
                         sharex=True)

for ax_idx, ax in enumerate(axes.flat):
    #flatten을 통해 vectorization을 해주는 것
    ax.plot(t.flatten(), data[ax_idx])
    ax.set_title(title_list[ax_idx],
                 fontsize=30)
    ax.tick_params(labelsize=20)
    ax.grid()
    if ax_idx == 2:
        ax.set_ylim([-3, 3])
fig.subplots_adjust(left=0.1, right=0.95,
                    bottom=0.05, top=0.95)
axes[-1].set_xticks(x_ticks)
axes[-1].set_xticklabels(x_ticklabels)
plt.show()

업로드중..

0개의 댓글