CNN Model 구현

예갈조·2024년 12월 10일

Tumor Track Project

목록 보기
14/25

CNN 개념 설명


CNN



레이어 크기 계산


CNN Layer 별 출력 크기 계산하기



코드 요약


1. 클래스 생성자 CNNModel::CNNModel()

  1. conv Layer: CNN 특징 추출 부분. 4개의 합성곱 층과 풀링 층으로 구성
    1. 각 층의 구성
      1. Conv2d: 입력 데이터를 필터로 합성곱해 특징을 추출
      2. BatchNorm2d: 학습 안정화를 위해 배치 정규화를 수행
      3. ReLU: 비선형성을 추가하여 모델이 복잡한 패턴을 학습할 수 있도록 함
      4. MaxPool2d: 공간적 크기를 축소(샘플링)하여 계산량을 줄이고 중요한 특징만 남김
  2. FC Layer: 합성곱 레이어를 통과한 출력 데이터를 분류 작업에 사용하기 위해 Flatten한 뒤, Fully Conncected 레이어를 적용함
    1. 구성
      1. Linear(128 9 9, 512): 128개의 9×9 크기 특징 맵을 512차원 벡터로 변환
      2. ReLU: 비선형성을 추가
      3. Linear(512, 2): 2개의 클래스를 분류하는 이진 분류 레이어
  3. register_module 함수를 사용하여 conv와 fc를 모델의 서브모듈로 등록
    1. Python 환경에서는 self.convself.fc로 서브모듈을 정의하면 자동으로 torch.nn.Module에 등록되었는데 C++ LibTorch에서는 register_module(”name”, module)을 명시적으로 호출해주어야 한다.
    2. 모듈을 등록한 후 파라미터를 저장 및 로드할 수 있다.
      1. Python: torch.savetorch.load로 모델 저장 및 로드
      2. LibTorch: torch::savetorch::load로 모델 저장 및 로드

2. 순전파 함수 CNNModel::forward

입력 데이터를 처리하는 순전파(forward pass) 과정을 정의

  1. 텐서 차원 제거
    • 입력 데이터의 불필요한 차원 제거: x.squeeze(1)
  2. 합성곱 레이어 통과
    • conv 모듈을 사용해 데이터를 처리
  3. Flatten
    • 합성곱 결과를 1D 벡터로 변환: x.view({x.size(0), -1})
  4. 완전 연결 레이어 통과
    • fc 모듈을 통해 데이터를 최종 분류
    • ReLU(활성화함수)와 함께 특징을 변환하고 최종적으로 2개의 출력값을 생성
  5. 출력 반환
    • 최종 결과(예: 클래스 확률)

→ 해당 모델은 이진 분류 문제를 처리하도록 설계되어 있으나, Fully Connected Layer 출력 크기를 수정하면 다중 분류 문제에도 사용 가능



특징


  • 4단계 합성곱 블록: 점점 더 복잡한 특징을 추출.
  • BatchNorm + ReLU: 학습 안정성과 비선형성 추가.
  • MaxPooling: 공간 크기 축소 및 특징 요약.
  • Fully Connected Layer: 추출된 특징을 기반으로 최종 분류.



코드


  • 해당 코드는 PyTorch C++ 라이브러리(LibTorch)를 사용해 정의된 CNN(Convolutional Neural Network) 모델
  • 이미지 데이터를 입력받아 분류 작업을 수행하는 신경망 설계하기 위함
#include "model.h"

CNNModel::CNNModel() {
	// Convolutional layers
	conv = torch::nn::Sequential(
		// Conv1
		torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 32, 3).stride(1).padding(1)),  // Conv_1
		torch::nn::BatchNorm2d(32),                                                 // BatchNorm_1
		torch::nn::ReLU(),                                                          // ReLU_1
		torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)),             // MaxPool_1

		// Conv2
		torch::nn::Conv2d(torch::nn::Conv2dOptions(32, 64, 3).stride(1).padding(1)),// Conv_2
		torch::nn::BatchNorm2d(64),                                                 // BatchNorm_2
		torch::nn::ReLU(),                                                          // ReLU_2
		torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)),             // MaxPool_2

		// Conv3
		torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 128, 3).stride(1).padding(1)),// Conv_3
		torch::nn::BatchNorm2d(128),                                                // BatchNorm_3
		torch::nn::ReLU(),                                                          // ReLU_3
		torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)),             // MaxPool_3

		// Conv4
		torch::nn::Conv2d(torch::nn::Conv2dOptions(128, 128, 3).stride(1).padding(1)),// Conv_4
		torch::nn::BatchNorm2d(128),                                                // BatchNorm_4
		torch::nn::ReLU(),                                                          // ReLU_4
		torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(3))              // MaxPool_4
	);

	// Fully connected layers
	fc = torch::nn::Sequential(
		torch::nn::Linear(128 * 11 * 11, 512),  // Flatten된 입력 크기에 맞게 수정
		torch::nn::ReLU(),                    // ReLU_5
		torch::nn::Linear(512, 2)             // Fully Connected 2 -> 다중분류
	);

	// Register modules
	register_module("conv", conv);
	register_module("fc", fc);
}

torch::Tensor CNNModel::forward(torch::Tensor x) {
	//auto asd = x.get_device();
	//std::cout << "Input shape: " << x.sizes() << std::endl;
	x = x.squeeze(1);  // Remove the time dimension

	x = conv->forward(x);
	//std::cout << "After conv: " << x.sizes() << std::endl;

	// Flatten
	x = x.contiguous().view({ x.size(0), -1 });
	//std::cout << "After flatten: " << x.sizes() << std::endl;

	x = fc->forward(x);
	//std::cout << "Output shape: " << x.sizes() << std::endl;

	return x;
}





추가 공부

  • 배치 정규화

참고자료

위키독스nn.Conv2d

0개의 댓글