[NumPy] Axis

kkiyou·2021년 6월 8일
0

Data Science

목록 보기
3/11
post-custom-banner

axis는 축을 의미한다. 그리고 통상적으로 N차원 그래프는 아래와 같은 축을 가진다.

  • 1차원 그래프는 x축
  • 2차원 그래프는 y축, x축
  • 3차원 그래프는 z축, y축, x축


그리고 이 축은 위의 그림처럼 좌표평면에 구현할 수 있다. 1차원 그래프는 직선 상의 모음이고, 2차원 그래프는 수학을 배울 때 접했다.
반면 3차원 그래프는 직관적이지는 않다. 이로 인해 NumPy에서 3차원 배열의 연산을 이해하기 어렵다. 특히 위에 사진처럼 y축은 양수 값을 가지는 것이 아니라, 음수 값을 가진 형태로 반환된다. 따라서 3차원 배열을 시각화한 연산 과정을 서술한다.


우선 NumPy로 3차원 배열을 생성해보자.

import numpy as np
arr_3d = np.array([[[1, 3, 5, 7],
                   [2, 4, 6, 8]],
                 
                  [[10, 30, 50, 70],
                   [20, 40, 60, 80]],
                 
                  [[100, 300, 500, 700],
                   [200, 400, 600, 800]]])
>>> print(arr_3d)
[[[  1   3   5   7]
  [  2   4   6   8]]

 [[ 10  30  50  70]
  [ 20  40  60  80]]

 [[100 300 500 700]
  [200 400 600 800]]]

>>> arr_3d.shape
(3, 2, 4)

(z=3,y=2,x=4)(z = 3, y = 2, x = 4)인 배열이 만들어졌다. 이 배열은 아래 그림과 같다.

(3, 2, 4)

그리고 axis는 왼쪽에서 오른쪽으로 진행되는 숫자가 순서대로 가장 멀리서부터 안쪽으로의 tuple의 개수이다. 즉 가장 바깥쪽 tuple에 3개로 구분된 데이터가 들어있고, 중간에 2개로 구분된 데이터가 들어있으며, 가장 안쪽에 4개로 구분된 데이터가 들어있는 것이다. 즉 각 대괄호 안에 들어있는 데이터의 수로 이해할 수 있다.

위에서 언급했듯 y축은 음수 영역에 표현되는 것에 주의하자.


이를 다시 (z, y, z)로 표현하면 다음과 같다.

zz = 0
for z in arr_3d:
    yy = 0
    for y in z:
        xx = 0
        for x in y:
            print("(%d, %d, %d) = %d" %(zz, yy, xx, x))
            xx += 1  
        yy += 1
    zz += 1
(0, 0, 0) = 1
(0, 0, 1) = 3
(0, 0, 2) = 5
(0, 0, 3) = 7
(0, 1, 0) = 2
(0, 1, 1) = 4
(0, 1, 2) = 6
(0, 1, 3) = 8
(1, 0, 0) = 10
(1, 0, 1) = 30
(1, 0, 2) = 50
(1, 0, 3) = 70
(1, 1, 0) = 20
(1, 1, 1) = 40
(1, 1, 2) = 60
(1, 1, 3) = 80
(2, 0, 0) = 100
(2, 0, 1) = 300
(2, 0, 2) = 500
(2, 0, 3) = 700
(2, 1, 0) = 200
(2, 1, 1) = 400
(2, 1, 2) = 600
(2, 1, 3) = 800


axis=0

axis=0일 때 sum을 하면 어떻게 될까?

>>> np.sum(arr_3d, axis=0)
array([[111, 333, 555, 777],
       [222, 444, 666, 888]])

arr[z][y][x]arr[axis=0][axis=1][axis=2]이고 axis=0은 3차원 배열에서 z축을 의미한다. 따라서 np.sum(arr_3d, axis=0)은 z축을 기준으로 합친다는 의미이다.


이를 시각화하면, z축을 제거하며 합쳐진다.

axis=0


이를 다시 array로 표현하면 다음과 같다.
# [
 [[1, 3, 5, 7],
  [2, 4, 6, 8]]
+
 [[10, 30, 50, 70],
  [20, 40, 60, 80]]
+
 [[100, 300, 500, 700],
  [200, 400, 600, 800]]
# ]

z축 대괄호가 사라지며 z[0], z[1], z[2]가 더해진다. 위의 array에서 원래 z축 ,+로 대체된 것을 알 수 있다. 즉 삭제하는 축의 ,+로 바꾸고 계산한다고 생각할 수도 있다.

이를 다시 정리하면 다음과 같다.

[[(1 + 10 + 100), (3 + 30 + 300), (5 + 50 + 500), (7 + 70 + 700)],
  [(2 + 20 + 200), (4 + 40 + 400), (6 + 60 + 600), (8 + 80 + 800)]]


axis=1

axis=1일 때 sum을 하면 다음과 같다.

>>> np.sum(arr_3d, axis=1)
array([[   3,    7,   11,   15],
       [  30,   70,  110,  150],
       [ 300,  700, 1100, 1500]])

arr[z][y][x]arr[axis=0][axis=1][axis=2]이고 axis=1은 3차원 배열에서 y축을 의미한다. 따라서 np.sum(arr_3d, axis=1)은 y축을 기준으로 합친다는 의미이다.


이를 시각화하여 보자.

axis=1

우선 위쪽 그림처럼 y축을 제거하며 합쳐진다. 그런데 이는 우리가 아는 2차원 좌표평면과 다르다. 따라서 우리가 아는 2차원 좌표평면으로 3차원 이동을 하면 아래쪽 그림처럼 변한다.


이를 다시 array로 표현하면 다음과 같다.

[
# [
   [1, 3, 5, 7]
  +[2, 4, 6, 8]
# ]
,

# [
   [10, 30, 50, 70]
  +[20, 40, 60, 80]
# ]
,

# [
   [100, 300, 500, 700]
  +[200, 400, 600, 800]
# ]
]

y축 대괄호가 사라지며 y[0], y[1]이 더해진다. 위의 array에서 원래 y축 ,+로 대체된 것을 알 수있다. 즉 삭제하는 축의 ,를 +로 바꾸고 계산한다고 생각할 수도 있다.

이를 다시 정리하면 다음과 같다.

[[(1 + 2), (3 + 4), (5 + 6), (7 + 8)],
 [(10 + 20), (30 + 40), (50 + 60), (70 + 80)],
 [(100 + 200), (300 + 400), (500 + 600), (700 + 800)]]


axis=2

마지막으로 axis=2일 때 sum을 하면 다음과 같다.

>>> np.sum(arr_3d, axis=1)
array([[  16,   20],
       [ 160,  200],
       [1600, 1700]])

arr[z][y][x]arr[axis=0][axis=1][axis=2]이고 axis=2은 3차원 배열에서 x축을 의미한다. 따라서 np.sum(arr_3d, axis=2)는 x축을 기준으로 합친다는 의미이다.


이를 다시 시각화하여 보자.

axis=2
우선 왼쪽 그림처럼 x축을 제거하며 합쳐진다. 그런데 이는 우리가 아는 2차원 좌표평면과 다르다. 따라서 우리가 아는 2차원 좌표평면으로 3차원 이동을 하면 오른쪽 그림처럼 변한다.


그런데 왜 y축이 x축의 역할을 할까? 원래 배열은 [z][y][x]였다. 그리고 우리는 [x]를 제거했다. 그러면 [z][y] 배열이 된다. 즉, [z]는 기존에 우리가 알던 2차원 그래프에서 y축의 위치에, y는 x축의 위치에 있다. 바로 이 위치에 따라서 결정된 것이다.


이를 다시 array로 표현하면 다음과 같다.

[[
#  [
    1 + 3 + 5 + 7
#  ]
,
#  [
    2 + 4 + 6 + 8
#  ]
 ],

 [
#  [
    10 + 30 + 50 + 70
#  ]
,
#  [
    20 + 40 + 60 + 80
#  ]
 ],

 [
#  [
    100 + 300 + 500 + 700
#  ]
,
#  [
    200 + 400 + 600 + 800
#  ]
 ]]

x축 대괄호가 사라지며 x[0], x[1], x[2], x[3]가 더해진다. 위의 array에서 원래 x축 ,+로 대체된 것을 알 수있다. 즉 삭제하는 축의 ,를 +로 바꾸고 계산한다고 생각할 수도 있다.

이를 다시 정리하면 다음과 같다.

[[(1 + 3 + 5 + 7), (2 + 4 + 6 + 8)],
 [(10 + 30 + 50 + 7), (20 + 40 + 60 + 80)],
 [(100 + 300 + 500 + 700), (200 + 400 + 600 + 800)]]

post-custom-banner

0개의 댓글