The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

boingboing·2023년 9월 18일
0

현상


def relu_backward(dout, cache):
    """
    Computes the backward pass for a layer of rectified linear units (ReLUs).

    Input:
    - dout: Upstream derivatives, of any shape
    - cache: Input x, of same shape as dout

    Returns:
    - dx: Gradient with respect to x
    """
    dx, x = None, cache 
    print("x, shape:", x.shape) # shape of x : 10, 10
    
    if(x < 0):
        dx = 0
    else:
        dx = 1 
        
    return dx

원인

  • x<0 이 부분에서 발생.
  • array와 scalar를 저렇게 비교할 수 없음.
  • element wise인지 array 전체 대상인지 명시해 줘야 함.
import numpy as np

a = np.array([1, 2, 3, 4])
b = np.array([1, 2, 3, 4])

if (a==b):
  print("a와 b가 같음")

-> a==b의 결과
# array([True, True, True, True])
  • numpy 배열 자체가 같은지 아닌지를 보는 게 아니라
    numpy array 2개 == -> 각 위치의 자료가 일치하는지 여부를 리턴함.
    배열 자체의 비교인지, element끼리 비교인지 불확실
    ->

if (a==b).all():
  print("a와 b가 같습니다!")
  # array 내의 모든 값들이 같음->전부 True-> 위 메세지가 출력 됨1개라도 True인지 체크하고 싶다면 any 

해결

    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            cur_elem = dout[i, j]

            if(cur_elem<0):
                dx[i, j] = 0
            else:
                dx[i, j] = 1

0개의 댓글