CNN Visualization

hu22·2023년 10월 8일
1
post-custom-banner

1.Visualizing CNN

1-1.What is CNN visualization?

  • CNN은 학습가능한 convolution과 비선형 활성화 함수로 이루어진 연산기로 볼 수 있음.
  • 학습을 시키다보면 인간 이상의 성능을 보여주기도 하고, 잘 되지 않는 경우도 있지만 CNN의 내부는 블랙박스 구조에 가깝기 때문에 왜 잘되고 왜 잘 안되는지 직접적으로 판단하기 어려움.
  • 이러한 문제를 해결하기 위해 CNN을 visualization하는 기술이 등장.
  • Visualization tool이 debugging tool처럼 작용함.

Example

  • CNN layer들이 각 위치에 따라서 어떤 것을 학습했는지를 convolution의 역연산인 deconvolution을 이용해 visualization한 연구가 있었음.
  • 그 결과, 낮은 계층에서는 낮은 layer에서는 가로/세로 무늬, 물체의 형태같은 구체적인 정보를 담고있는 필터를 확인할 수 있었고, 높은 layer로 갈수록 추상적이고 의미론적인 필터를 확인할 수 있었음.

1-2. Fiter visualization (가장 간단한 visualization 기법)

  • AlexNet에 filter weight visualization을 적용해보면 첫 convolution layer의 filter의 channel이 3이므로 컬러이미지로 시각화할 수 있음.
  • 그 결과 컬러, 각도, 블록 등 다양한 정보를 학습한 filter를 확인할 수 있음.
  • Activation을 취하면 Channel size가 1이 되므로, Activation을 취한 뒤의 결과도 흑백 이미지로도 visualization을 할 수 있음.
  • 이를 통해 CNN의 필터들이 어떤 정보를 집중적으로 학습했는지 확인할 수 있음.

1-3. How to Visualize Neral Network

  • 크게 model자체의 특성을 분석하는데 중점을 둔 방법과 어떻게 output이 나오게 되었는지에 중점을 둬서 분석하는 방법, 두 가지로 나눌 수 있음.

2. Analysis of model behaviors

  • model자체의 특성을 분석하는데 중점을 둔 방법.

2-1. Embedding feature analysis

Nearest neighbors in a feature analysis

  • query image를 입력하면 DB내에서 가장 유사한 이미지를 찾는 방식.
  • 예를 들어 코끼리 이미지를 입력했는데 유사한 이미지가 나오는 것을 보아, embedding space내에서 의미론적으로 잘 clustering을 이루고 있는 것을 확인할 수 있음.
  • 파란박스의 이미지를 보면 강아지의 위치와 자세가 query image와 다르므로 벡터 간의 거리 차가 있음을 알 수 있지만 NN을 했을 때 강아지 이미지가 나오는 것을 보면 의미론적으로 잘 clustering 됐음을 확인할 수 있음.
  • 원래 Neural Net에서는 embedding vector를 기반으로 마지막 classification task를 진행하지만 visualization에서는 embedding vector 그 자체를 DB의 이미지와 매칭시킴.

Dimensionality reduction

  • embedding vector는 사실 너무 고차원이라 인간이 해석하기 힘듦. 그래서 feature space의 차원을 축소해야함.
  • 대표적인 방법으로 t-SNE가 있음.(추후 공부)

2-2. Activation investigation

Maximally activation patches

  • 이미지를 입력해서 특정layer의 activation map을구하고 이 맵에서 가장 큰 값을 가지는 patch의 위치 정보를 저장. 이후 이 위치에 해당하는 receptive field의 해당하는 이미지를 확인해보면 그 특정 layer가 어떤 것을 중요로 하는지 확인할 수 있음.

Class visualization

  • activation을 분석하기 위해 데이터에서 추출한 특정 이미지를 활용한 위의 방법과 달리 데이터를 사용하지 않고 네트워크의 안에 있는 정보를 시각화.
  • 위의 그림을 보면 "새" class로 분류하기 위해 나뭇가지의 형상을 중요시 한 것을 볼 수 있음.
  • Gradient Ascent 기법을 사용함.

    -> Gradient Ascent에서 최적화에 사용하는 Loss function
    (II는 image, f(I)f(I)는 이미지를 CNN에 입력해주었을 때 출력된 하나의 class score)
    이미지에 대해 클래스 score를 예측-> loss function에 따라 backpropagation 하여 클래스 스코어를 최대화하도록 업데이트-> 업데이트 된 이미지를 input으로 하여 반복

profile
ai 개발자를 꿈꾸는 대학생
post-custom-banner

0개의 댓글