크리스마스 트리

Wonseok Lee·2021년 8월 23일
0

Beakjoon Online Judge

목록 보기
36/117
post-thumbnail

Problem link: https://www.acmicpc.net/problem/1234

문제를 보자마자 풀이는 바로 떠오른다.

아래와 같이 재귀 함수를 떠올리면 풀이가 끝난다.

  • Solve(current_level, current_r, current_g, current_b): 현재 current_level을 색칠해야하고, 남아있는 색이 current_{r|g|b}개 일 때, 전체 트리를 칠하는 경우의 수

현재 레벨에 칠해야할 색의 수는 current_level개 인데, 색 별로 수가 같아야 한다는 조건이 있으므로 아래 경우를 다 세어주면 된다.
(남아있는 장식품으로 칠할 수 있다는 가정하에, 각 경우는 같은 것이 있는 수열을 활용)

  • 1색으로 다 칠하는 경우(r|g|b)
  • 2색으로 다 칠하는 경우(rg|gb|rb)
  • 3색으로 다 칠하는 경우(rgb)

중간 중간 중복으로 세는 경우가 많으므로 DP를 써주면 더 좋은데, 안 써줘도 무리 없이 풀리는 입력의 크기를 가진 문제이다.

#include <iostream>
#include <cstdint>

using namespace std;

int64_t fact[11];

int64_t Solve(const int64_t level, const int64_t r, const int64_t g, const int64_t b, const int64_t n)
{
  if (level > n)
  {
    return 1;
  }

  int64_t ret = 0;

  if (level % 3 == 0 && r >= level / 3 && g >= level / 3 && b >= level / 3)
  {
    int64_t cnt = fact[level] / (fact[level / 3] * fact[level / 3] * fact[level / 3]);
    cnt *= Solve(level + 1, r - level / 3, g - level / 3, b - level / 3, n);
    ret += cnt;
  }

  if (level % 2 == 0)
  {
    if (r >= level / 2 && g >= level / 2)
    {
      int64_t cnt = fact[level] / (fact[level / 2] * fact[level / 2]);
      cnt *= Solve(level + 1, r - level / 2, g - level / 2, b, n);
      ret += cnt;
    }
    if (g >= level / 2 && b >= level / 2)
    {
      int64_t cnt = fact[level] / (fact[level / 2] * fact[level / 2]);
      cnt *= Solve(level + 1, r, g - level / 2, b - level / 2, n);
      ret += cnt;
    }
    if (r >= level / 2 && b >= level / 2)
    {
      int64_t cnt = fact[level] / (fact[level / 2] * fact[level / 2]);
      cnt *= Solve(level + 1, r - level / 2, g, b - level / 2, n);
      ret += cnt;
    }
  }

  if (r >= level)
  {
    ret += Solve(level + 1, r - level, g, b, n);
  }

  if (g >= level)
  {
    ret += Solve(level + 1, r, g - level, b, n);
  }

  if (b >= level)
  {
    ret += Solve(level + 1, r, g, b - level, n);
  }

  return ret;
}

int main(void)
{
  // Preprocess
  fact[0] = 1;
  for (int64_t it = 1; it < 11; ++it)
  {
    fact[it] = fact[it - 1] * it;
  }

  // For Faster IO
  ios_base::sync_with_stdio(false);
  cout.tie(nullptr);
  cin.tie(nullptr);

  // Read Inputs
  int64_t N, R, G, B;
  cin >> N >> R >> G >> B;

  // Solve
  cout << Solve(1, R, G, B, N) << '\n';

  return 0;
}
profile
Pseudo-worker

0개의 댓글