C++ | Softmax function 구현하기

Drakk·2021년 10월 13일
0

머신러닝

목록 보기
4/4
post-thumbnail

🔪소개및 개요

💎개요

이번에는 소프트맥스함수를 구현하겠습니다.

머신러닝에서 다중 클래스를 분류할때, 임의의 클래스의 확률적 수치를 추정할때 사용하는 함수입니다.

💎환경

통합개발환경: DevCpp
언어: C++17
운영체제: Windows11 Home
컴파일러: g++

🔪Softmax function

💎이론

우선 Softmax function은 이런식으로 생겼습니다.

딱 보기에도 엄청 심플하게 생겼습니다.
쉽게 설명드리자면, { 1, 2, 3 }이러한 배열이 들어왔다고 했을때...

분모쪽은 e1+e2+e3e^{1} + e^{2} + e^{3} 이런식으로 됩니다.
분자쪽은 함수로부터 전달된 수로 결정되죠.

예를들어서 sj = 1이라면 분자는 e1e^{1} 가 됩니다.

바로 코드실습 가겠습니다.

💎소스코드

#include <iostream>
#include <cmath>
#include <vector>

// https://www.HostMath.com/Show.aspx?Code=f(sj)%20%3D%20%5Cfrac%7Be%5E%7Bsj%7D%7D%7B%5Csum_%7Bi%3D1%7D%5E%7Bm%7De%5E%7Bsi%7D%7D
template<typename dataType>
double softmax(std::vector<dataType>& arr, dataType sj){
	if(!std::any_of(arr.begin(), arr.end(), [&sj](dataType& j){ return j == sj; })) throw std::runtime_error("Invalid value");
	
	dataType maxElement = *std::max_element(arr.begin(), arr.end());
	double sum = 0.0;
	for(auto const& i : arr) sum += std::exp(i - maxElement);
	
	return (std::exp(sj - maxElement) / sum);
}

int main() {
	std::vector<double> arr = { 2, 3, 5 };
	
	auto v_2 = softmax<double>(arr, 2);
	auto v_3 = softmax<double>(arr, 3);
	auto v_5 = softmax<double>(arr, 5);
	auto v_all = softmax<double>(arr, 2) + softmax<double>(arr, 3) + softmax<double>(arr, 5);
	
	std::cout << "v_2: " << v_2 << '(' << double(v_2 * 100) << "%)\n";
	std::cout << "v_3: " << v_3 << '(' << double(v_3 * 100) << "%)\n";
	std::cout << "v_5: " << v_5 << '(' << double(v_5 * 100) << "%)\n";
	std::cout << "v_all: " << v_all << '(' << double(v_all * 100) << "%)";
}

우선 함수 실행전에 std::any_of 를 사용했습니다. 그 이유는 배열 arr 안에 없는 수는 처리하지 않기 위함이죠.
배열 arr 에 포함되지 않은 수를 sj 로 집어넣으면 절대적으로 제대로 된 수치를 얻어낼 수 없기때문입니다.
그래서 런타임에러 예외사항을 std::any_of 를 이용하여 집어낼 겁니다.

자...
여기서 의문점이 드시는분들이 있으실 겁니다.

"왜 배열중의 최댓값을 빼지??"

흠.. 결과적으로 말하면 오버플로우를 막기 위함입니다.
Softmax function 에 사용되는 자연상수e는 지수함수의 성질을 가집니다.
지수함수는 수가 크면 클수록 기하급수적으로 그래프가 상승하거나 감소합니다.
그래서 수가 조금만 커지면 프로그램은 오버플로우를 뱉어냅니다.

따라서 최댓값을 뺌으로써 큰 값이 들어와도 오버플로우를 방지할 수 있습니다.

크게 예외 사항이 있다고 한다면 배열 arr 에 들어오는 수간의 차이값이 크면은 가끔 오버플로우를 뱉어낼겁니다.

🔪마무리

💎느낀점

소프트맥스함수를 직접 구현해보니 킬링타임용으로 재미있었다.

💎마치며...

궁금한 부분있으면 댓글로 질문주세요..!
그럼 안녕~~!!~!~!

profile
C++ / Assembly / Python 언어를 다루고 있습니다!

0개의 댓글