CNNModel::CNNModel()conv Layer: CNN 특징 추출 부분. 4개의 합성곱 층과 풀링 층으로 구성FC Layer: 합성곱 레이어를 통과한 출력 데이터를 분류 작업에 사용하기 위해 Flatten한 뒤, Fully Conncected 레이어를 적용함register_module 함수를 사용하여 conv와 fc를 모델의 서브모듈로 등록self.conv와 self.fc로 서브모듈을 정의하면 자동으로 torch.nn.Module에 등록되었는데 C++ LibTorch에서는 register_module(”name”, module)을 명시적으로 호출해주어야 한다.torch.save와 torch.load로 모델 저장 및 로드torch::save와 torch::load로 모델 저장 및 로드CNNModel::forward입력 데이터를 처리하는 순전파(forward pass) 과정을 정의
x.squeeze(1)conv 모듈을 사용해 데이터를 처리x.view({x.size(0), -1})fc 모듈을 통해 데이터를 최종 분류→ 해당 모델은 이진 분류 문제를 처리하도록 설계되어 있으나, Fully Connected Layer 출력 크기를 수정하면 다중 분류 문제에도 사용 가능
#include "model.h"
CNNModel::CNNModel() {
// Convolutional layers
conv = torch::nn::Sequential(
// Conv1
torch::nn::Conv2d(torch::nn::Conv2dOptions(3, 32, 3).stride(1).padding(1)), // Conv_1
torch::nn::BatchNorm2d(32), // BatchNorm_1
torch::nn::ReLU(), // ReLU_1
torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)), // MaxPool_1
// Conv2
torch::nn::Conv2d(torch::nn::Conv2dOptions(32, 64, 3).stride(1).padding(1)),// Conv_2
torch::nn::BatchNorm2d(64), // BatchNorm_2
torch::nn::ReLU(), // ReLU_2
torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)), // MaxPool_2
// Conv3
torch::nn::Conv2d(torch::nn::Conv2dOptions(64, 128, 3).stride(1).padding(1)),// Conv_3
torch::nn::BatchNorm2d(128), // BatchNorm_3
torch::nn::ReLU(), // ReLU_3
torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(2)), // MaxPool_3
// Conv4
torch::nn::Conv2d(torch::nn::Conv2dOptions(128, 128, 3).stride(1).padding(1)),// Conv_4
torch::nn::BatchNorm2d(128), // BatchNorm_4
torch::nn::ReLU(), // ReLU_4
torch::nn::MaxPool2d(torch::nn::MaxPool2dOptions(2).stride(3)) // MaxPool_4
);
// Fully connected layers
fc = torch::nn::Sequential(
torch::nn::Linear(128 * 11 * 11, 512), // Flatten된 입력 크기에 맞게 수정
torch::nn::ReLU(), // ReLU_5
torch::nn::Linear(512, 2) // Fully Connected 2 -> 다중분류
);
// Register modules
register_module("conv", conv);
register_module("fc", fc);
}
torch::Tensor CNNModel::forward(torch::Tensor x) {
//auto asd = x.get_device();
//std::cout << "Input shape: " << x.sizes() << std::endl;
x = x.squeeze(1); // Remove the time dimension
x = conv->forward(x);
//std::cout << "After conv: " << x.sizes() << std::endl;
// Flatten
x = x.contiguous().view({ x.size(0), -1 });
//std::cout << "After flatten: " << x.sizes() << std::endl;
x = fc->forward(x);
//std::cout << "Output shape: " << x.sizes() << std::endl;
return x;
}