Branchless Programing

Yoon Sunkue·2021년 9월 26일
1

asm 단계에서 jmp를 줄이고, cpu pipline의 branch predict miss 로 인한 pipline_stall 을 최소화. 할거없으면 기존에 만들어두신 프로그램에서 도전해보자.
간단하다, if, switch 를 최대한 없애버리면 된다.

예시 코드를 보기전에, 먼저 몇가지를 알아두셧으면 한다, msvc 의 switch 는 어느 정도의 case가 없으면 if_else의 반복으로 생성되고, 좀 많다 싶으면 테이블을 생성한다,
아래는 12개의 case 이고, 4개의 case 일때와 디스어셈블 결과가 다르게 나온다.
4개일때의 어셈은 이 링크에서 확인하시라.
https://stackoverflow.com/questions/69334314/pipeline-stall-optimize-no-branch-programing

예시는 이렿다 내 씨피유에서는 조금 더 얹어서 2배 빨라졌다.

#include <iostream>
#include <chrono>

using clk = std::chrono::high_resolution_clock;
using namespace std::chrono;
using namespace std::literals::string_view_literals;

namespace timer {
	static clk::time_point StopWatch;

	inline void start() {
		StopWatch = clk::now();
	}

	inline void end(const std::string_view mess = ""sv)
	{
		auto t = clk::now();
		std::cout << mess << " : " << duration_cast<milliseconds>(t - StopWatch) << '\n';
	}
}

// controll //
#define noBranch
#define noInline
// controll //


#ifdef noInline
#define INLINE __declspec(noinline)
#else 
#define INLINE 
#endif

class OBJ {
public:
	size_t x = 0;
	INLINE void f1() {
		x += 13;
	}
	INLINE void f2() {
		x += 23;
	}
	INLINE void f3() {
		x += 18;
	}
	INLINE void f4() {
		x += 15;
	}
	INLINE void f5() {
		x += 132;
	}
	INLINE void f6() {
		x += 232;
	}
	INLINE void f7() {
		x += 182;
	}
	INLINE void f8() {
		x += 152;
	}
	INLINE void f9() {
		x += 131;
	}
	INLINE void f10() {
		x += 231;
	}
	INLINE void f11() {
		x += 181;
	}
	INLINE void f12() {
		x += 151;
	}
};

int main()
{
	size_t sum = 0;
	std::string in;
	std::cin >> in;
	timer::start(); 
	for (size_t q = 0; q < 1'000'000; q++) {
		for (const auto i : in) {
			OBJ a;
#ifdef noBranch
			static decltype(&OBJ::f1) func[] = 
			{ &OBJ::f1, &OBJ::f2, &OBJ::f3, &OBJ::f4,
				&OBJ::f5, &OBJ::f6, &OBJ::f7, &OBJ::f8,
				&OBJ::f9, &OBJ::f10, &OBJ::f11, &OBJ::f12 };
			(a.*func[i - '0'])();
#else
			switch (i - '0')
			{
			case 0: a.f1(); break;
			case 1: a.f2(); break;
			case 2: a.f3(); break;
			case 3: a.f4(); break;
			case 4: a.f5(); break;
			case 5: a.f6(); break;
			case 6: a.f7(); break;
			case 7: a.f8(); break;
			case 8: a.f9(); break;
			case 9: a.f10(); break;
			case 10: a.f11(); break;
			case 11: a.f12(); break;
			}
#endif
			sum += a.x;
		}
	}
	std::cout << "sum" << sum << std::endl;
	timer::end();
}

input은 숫자 대충 때려 박아주면 된다.

121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827121327412057612094769276310256012974430815620927438017430126359871234981203498762156981734827

디스 어셈블리는 이렇게 나온다.

//asm
 			static decltype(&OBJ::f1) func[] = 
    86: 			{ &OBJ::f1, &OBJ::f2, &OBJ::f3, &OBJ::f4,
    87: 				&OBJ::f5, &OBJ::f6, &OBJ::f7, &OBJ::f8,
    88: 				&OBJ::f9, &OBJ::f10, &OBJ::f11, &OBJ::f12 };
    89: 			(a.*func[i - '0'])();
00007FF79B111429  movsx       rax,byte ptr [rbx]  
00007FF79B11142D  lea         rcx,[a]  
00007FF79B111432  call        qword ptr [r13+rax*8-180h]
//asm
switch (i - '0')
00007FF620461434  movsx       ecx,byte ptr [rax]  
00007FF620461437  add         ecx,0FFFFFFD0h  				// 1 !
00007FF62046143A  cmp         ecx,0Bh  					// 2 !
00007FF62046143D  ja          main+185h (07FF6204614D5h)  		// 3 !
00007FF620461443  movsxd      rcx,ecx  					// 4 !
00007FF620461446  mov         edx,dword ptr [r11+rcx*4+1614h]  		// 5 !
00007FF62046144E  add         rdx,r11  					// 6 !
00007FF620461451  jmp         rdx  					// 7 !
    92: 			{
    93: 			case 0: a.f1(); break;
00007FF620461453  lea         rcx,[rbp-30h]  
00007FF620461457  call        OBJ::f1 (07FF620461290h)  
00007FF62046145C  jmp         main+185h (07FF6204614D5h)  		// 8 ! 개의 asm 제거 됨.
    94: 			case 1: a.f2(); break;
00007FF62046145E  lea         rcx,[rbp-30h]  
00007FF620461462  call        OBJ::f2 (07FF6204612A0h)  
00007FF620461467  jmp         main+185h (07FF6204614D5h)  
    95: 			case 2: a.f3(); break;
00007FF620461469  lea         rcx,[rbp-30h]  
00007FF62046146D  call        OBJ::f3 (07FF6204612B0h)  
00007FF620461472  jmp         main+185h (07FF6204614D5h)  
    96: 			case 3: a.f4(); break;
00007FF620461474  lea         rcx,[rbp-30h]  
00007FF620461478  call        OBJ::f4 (07FF6204612C0h)  
00007FF62046147D  jmp         main+185h (07FF6204614D5h)  
    97: 			case 4: a.f5(); break;
00007FF62046147F  lea         rcx,[rbp-30h]  
00007FF620461483  call        OBJ::f5 (07FF6204612D0h)  
00007FF620461488  jmp         main+185h (07FF6204614D5h)  
    98: 			case 5: a.f6(); break;
00007FF62046148A  lea         rcx,[rbp-30h]  
00007FF62046148E  call        OBJ::f6 (07FF6204612E0h)  
00007FF620461493  jmp         main+185h (07FF6204614D5h)  
    99: 			case 6: a.f7(); break;
00007FF620461495  lea         rcx,[rbp-30h]  
00007FF620461499  call        OBJ::f7 (07FF6204612F0h)  
00007FF62046149E  jmp         main+185h (07FF6204614D5h)  
   100: 			case 7: a.f8(); break;
00007FF6204614A0  lea         rcx,[rbp-30h]  
00007FF6204614A4  call        OBJ::f8 (07FF620461300h)  
00007FF6204614A9  jmp         main+185h (07FF6204614D5h)  
   101: 			case 8: a.f9(); break;
00007FF6204614AB  lea         rcx,[rbp-30h]  
00007FF6204614AF  call        OBJ::f9 (07FF620461310h)  
00007FF6204614B4  jmp         main+185h (07FF6204614D5h)  
   102: 			case 9: a.f10(); break;
00007FF6204614B6  lea         rcx,[rbp-30h]  
00007FF6204614BA  call        OBJ::f10 (07FF620461320h)  
00007FF6204614BF  jmp         main+185h (07FF6204614D5h)  
   103: 			case 10: a.f11(); break;
00007FF6204614C1  lea         rcx,[rbp-30h]  
00007FF6204614C5  call        OBJ::f11 (07FF620461330h)  
00007FF6204614CA  jmp         main+185h (07FF6204614D5h)  
   104: 			case 11: a.f12(); break;
00007FF6204614CC  lea         rcx,[rbp-30h]  
00007FF6204614D0  call        OBJ::f12 (07FF620461340h)  
   105: 			}

정말 단순하게 함수 포인터 테이블만 써도 8개의 asm 명령어가 줄어들었다. branch 를 branchless 는 아니더라도 less_branch 하게 만들 수 있다.

완전한 branchless 예는 이렇다.

namespace optimize {
		// branchless
		template<integral _Ty> inline constexpr _Ty abs(const _Ty x) noexcept {
			const _Ty y{ x >> (sizeof(_Ty) * 8 - 1) };
			return (x ^ y) - y;
		}

		// branchless 
		inline constexpr char toupper_alphabets(char c) noexcept {
			return c -= 32 * ('a' <= c && c <= 'z');
		}

		// branchless 
		inline constexpr char toupper_alphabets(char c) noexcept {
			return c += 32 * ('A' <= c && c <= 'Z');
		}
	}

0개의 댓글