C++: 컴파일 시간 조합 구현

이경헌·2023년 10월 21일

도입

C++23에 이르러 std::views::cartesian_product를 사용하여 중첩된 for 반복문을 가독성있게 바꿀 수 있습니다. 예를 들어, 다음 numschars를 순서대로 순회해가며 출력할 경우

  • 반복문을 사용하여 해결
std::array nums { 1, 2, 3, 4 };
std::string_view chars = "abcde"sv;

for (int n : nums){
    for (char c : chars){
        std::println("{} {}", n, c);
    }
}
  • cartesian_product를 사용하여 해결
std::array nums { 1, 2, 3, 4 };
std::string_view chars = "abcde"sv;

for (auto [n, c] : std::views::cartesian_product(nums, chars)){
    std::println("{} {}", n, c);
}

과 같이 하나의 순회로 해결할 수 있고, 중첩 개수가 늘어날수록 더욱 효과적입니다.

combinations 구현

구현

하지만 특정 O(n2)O(n^2) 알고리즘, 그 중에서도 nn개의 원소 중 중복 없이 kk쌍을 뽑는 조합 문제에서는 내부 반복문이 그 이전 반복문에 의존함에 따라 위 함수를 사용할 수 없습니다. 이러한 알고리즘에는 대표적으로 버블 정렬, Two-sum 문제, 충돌 물체 쌍 찾기 등이 있습니다.

다음 코드는 해당 문제를 해결하기 위해 작성된 코드입니다. Concept와 Range를 이용하므로, C++ 20 또는 그 이상의 버전을 요구합니다.

#include <ranges>
#include <functional>

#define FWD(x) std::forward<decltype(x)>(x)

template <std::size_t, typename U>
using pair_second_t = U;

template <typename Fn, std::size_t N, typename T, std::size_t... Is>
struct is_n_invocable : is_n_invocable<Fn, N - 1, T, Is..., sizeof...(Is)>::type { };

template <typename Fn, typename T, std::size_t... Is>
struct is_n_invocable<Fn, 0, T, Is...> : std::is_invocable<Fn, pair_second_t<Is, T>...> { };

/**
 * Helper concept, same to std::invocable<Fn, T, T, ..., T> (which T are repeated N times)
 * @tparam Fn Function type to invoke.
 * @tparam N Number of repetitions.
 * @tparam T Type to passed to \p Fn.
 */
template <typename Fn, std::size_t N, typename T>
concept n_invocable = is_n_invocable<Fn, N, T>::value;

template <std::size_t N>
struct combinations_fn{
    template <std::ranges::forward_range R, n_invocable<N, std::ranges::range_common_reference_t<R>> F, std::forward_iterator... Its>
    constexpr void operator()(R &&range, F &&f, Its... its) const{
        if constexpr (sizeof...(Its) == N){
            std::invoke(FWD(f), FWD(*its)...);
        }
        else{
            auto it = [&](){
                if constexpr (sizeof...(Its) == 0){
                    return std::begin(range);
                }
                else{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-value"
                    auto it = (its, ...);
#pragma clang diagnostic pop
                    return ++it;
                }
            }();
            auto end = std::end(range);

            for (; it != end; ++it){
                operator()(FWD(range), FWD(f), its..., it);
            }
        }
    }
};

/**
 * Invoke function with \p N sequentially chosen elements in the range.
 * @tparam N Number of element to choose. Should be <tt>0 ≤ N ≤ range.size()</tt>.
 * @example
 * @code
 * std::array nums { 1, 2, 3, 4 };
 * combinations<2>(nums, [](int x, int y){
 *     std::println("{} {}", x, y);
 * };
 *
 * // The output will be:
 * // 1 2
 * // 1 3
 * // 1 4
 * // 2 3
 * // 2 4
 * // 3 4
 * @endcode
 */
template <std::size_t N>
inline constexpr combinations_fn<N> combinations;

사용례

사용례는 다음과 같습니다.

int sum = 0;

// 1부터 100까지의 수 중 중복되지 않게 서로 다른 세 쌍의 수를 뽑아, 그 수의 합을 sum 변수에 누적함.
combinations<3>(std::views::iota(1, 101), [&](auto ...chosen){
    sum += (chosen + ...);
});

첫 번째 인자인 rangestd::ranges::forward_range를 만족하여야 합니다. 따라서, 다음과 같이 위 문제 중 3의 배수만을 뽑아 해결하는 것도 가능합니다 (rangebidirectional_range를 모델).

int sum = 0;

// 1부터 100까지의 수 중 3의 배수를 중복되지 않게 서로 다른 세 쌍의 수를 뽑아, 그 수의 합을 sum 변수에 누적함.
auto one_to_hundred_only_multiple_of_three = 
    std::views::iota(1, 101) 
    | std::views::filter([](int n) { return n % 3 == 0; });
combinations<3>(one_to_hundred_only_multiple_of_three, [&](auto ...chosen){
    sum += (chosen + ...);
});

성능 분석

해당 함수의 성능 분석을 위해 nanobench 라이브러리를 사용해 첫 번째 문제를 해결하는 시간을 측정하였습니다. 정확도를 향상시키기 위해 1부터 1000까지의 수에 대해 계산했습니다.

#define ANKERL_NANOBENCH_IMPLEMENT
#include "nanobench.h"

// combinations code implementation should be in here.

int main() {
    using namespace ankerl;

    nanobench::Bench().run("nested for loop", []{
        int sum = 0;
        for (int i = 1; i < 1001; ++i){
            for (int j = i + 1; j < 1001; ++j){
                for (int k = j + 1; k < 1001; ++k){
                    sum += i + j + k;
                }
            }
        }

        nanobench::doNotOptimizeAway(sum);
    });

    nanobench::Bench().run("combinations", []{
        int sum = 0;
        combinations<3>(std::views::iota(1, 1001), [&](auto ...chosen){
            sum += (chosen + ...);
        });

        nanobench::doNotOptimizeAway(sum);
    });
}

벤치마크 결과

ns/opop/serr%totalbenchmark
495,000.002,020.200.1%0.01nested for loop
192,608.405,191.882.2%0.01combinations

원인을 정확히 파악할 수 없으나, 자체 구현한 combinations가 약 2.5배 가량 빠른 것으로 분석되었습니다. 빌드 환경은 M1 Pro, 컴파일러 Clang 17이며 -O3 -DNDEBUG로 릴리즈 빌드하였습니다. 해당 코드를 사용하셔도 좋을 것 같습니다.

profile
Undergraduate student in Korea University. Major in electrical engineering and computer science.

0개의 댓글