axis 이해하기
- 몇몇 함수에는 axis keyword 파라미터가 존재
- axis값이 없는 경우에는 전체 데이터에 대해 적용
- axis값이 있는 경우에는, 해당 axis를 따라서 연산 적용
- 거의 대부분의 연산 함수들이 axis 파라미터를 사용
- 이 경우, 해당 값이 주어졌을 때, 해당 axis를 따라서 연산이 적용
- 따라서 결과는 해당 axis가 제외된 나머지 차원의 데이터만 남게 됨
- 예) np.sum, np.mean, np.any 등등
- axis는 열의 증가방향이 기준이라고 생각하면 된다. 1차원일때는 열의 증가방향박에 없고 2차원 행렬일 경우 열의 증가방향의 axis=1, 행의 증가방향이 axis=2 삼차원 텐서의 경우 열의 증가방향이 axis=2. 행의증가뱡향 axis=1, 깊이의 증가방향 axis=0. 즉, 차원이 하나씩 증가할수록 열의 증가방향에 대한 axis값이 1씩 증가한다고 생각하면 된다.
import numpy as np
y = np.arange(9).reshape(3, 3)
print(y)
[[0 1 2]
[3 4 5]
[6 7 8]]
np.sum(y, axis=0)
array([ 9, 12, 15])
np.sum(y, axis=1)
array([ 3, 12, 21])
z = np.arange(24).reshape(2, 3, 4)
print(z)
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
np.sum(z, axis=0)
array([[12, 14, 16, 18],
[20, 22, 24, 26],
[28, 30, 32, 34]])
np.sum(z, axis=1)
array([[12, 15, 18, 21],
[48, 51, 54, 57]])
np.sum(z, axis=2)
array([[ 6, 22, 38],
[54, 70, 86]])
- 해당 튜플에 명시된 모든 axis에 대해서 연산
np.sum(z, axis=(0, 1))
array([60, 66, 72, 78])