부분수열의 합 2

Wonseok Lee·2021년 8월 23일
0

Beakjoon Online Judge

목록 보기
37/117
post-thumbnail

Problem link: https://www.acmicpc.net/problem/1208

처음에는 DP로 풀었고, 다른 풀이가 또 있나를 찾아보다가 투 포인터 풀이로도 풀어보았다(둘 다 무난하게 AC는 받는다).

DP풀이 같은 경우에는 아래와 같이 캐시/점화식을 정의하였다.

  • CACHE[i][s]: i번째 수까지 사용해서 s를 만들 수 있는 경우의 수
  • CACHE[i][s]: CACHE[i-1][s] + CACHE[i-1][s-arr[i]]

위의 풀이가 가능한 이유는 s가 비교적 작은 범위로(40개 수의 최대/최소 범위가 +-40*100000 밖에 되지 않는다.)

아래 사항 정도를 주의해주면 DP로 푸는데 큰 어려움이 없다.

  • 정답(경우의 수) 범위가 int를 넘어갈 수 있다.
  • 아무것도 고르지 않고 0을 만드는 경우의 수를 빼주어야 한다.

DP풀이의 속도가 조금 맘에 들지 않아서, 다른 풀이를 찾아보았는데 역시 더 좋은 풀이가 있었다.

일단, 주어진 숫자 배열을 2등분한다(좌/우라고 하자).

좌/우에 대해서 각각 나올 수 있는 모든 부분합을 구해서 저장해주자(최대 20개 씩이므로 큰 무리없이 구할 수 있다).

좌/우에 대해서 합 S를 찾는 투 포인터를 돌려주자.

주의할 점은, 처음에 구현을 간단히 하려고 map을 사용했는데, 이렇게하면 TLE가 뜬다.

좌표압축 문제들에서의 교훈과 동일하게, 되도록 flat한 map을 만들도록 하자.

DP Solution

#include <iostream>
#include <cstdint>

using namespace std;

static const int kMaxN = 40;
static const int kMaxI = 100000;
static const int kMaxSum = kMaxN * kMaxI;

inline int Sum2Idx(const int sum)
{
  return sum + kMaxSum;
}

int NUMS[kMaxN];
uint64_t CACHE[2][2 * kMaxSum + 1];

int main(void)
{
  // For Faster IO
  ios_base::sync_with_stdio(false);
  cout.tie(nullptr);
  cin.tie(nullptr);

  // Read Input
  int N, S;
  cin >> N >> S;

  for (int it = 0; it < N; ++it)
  {
    cin >> NUMS[it];
  }

  // Solve
  CACHE[0][Sum2Idx(0)] += 1;
  CACHE[0][Sum2Idx(NUMS[0])] += 1;
  for (int i = 1; i < N; ++i)
  {
    for (int s = -N * kMaxI; s <= N * kMaxI; ++s)
    {
      CACHE[i % 2][Sum2Idx(s)] = CACHE[(i - 1) % 2][Sum2Idx(s)];
      if (-N * kMaxI <= s - NUMS[i] && s - NUMS[i] <= N * kMaxI)
      {
        CACHE[i % 2][Sum2Idx(s)] += CACHE[(i - 1) % 2][Sum2Idx(s - NUMS[i])];
      }
    }
  }

  // Print answer
  cout << CACHE[(N - 1) % 2][Sum2Idx(S)] - (S == 0 ? 1 : 0) << '\n';

  return 0;
}

Two Pointers Solution

#include <iostream>
#include <cstdint>
#include <vector>
#include <algorithm>

using namespace std;

vector<int> NUMS;

int64_t Solve(const vector<int>& nums, const int target)
{
  size_t lower = nums.size() / 2;
  size_t upper = nums.size() - lower;

  vector<int> lower_map((size_t)1 << lower);
  vector<int> upper_map((size_t)1 << upper);

  // Lower
  for (size_t mask = 0; mask < ((size_t)1 << lower); ++mask)
  {
    int sum = 0;
    for (size_t idx = 0; idx < lower; ++idx)
    {
      if (((size_t)1 << idx) & mask)
      {
        sum += nums[idx];
      }
    }

    lower_map[mask] = sum;
  }

  // Upper
  for (size_t mask = 0; mask < ((size_t)1 << upper); ++mask)
  {
    int sum = 0;
    for (size_t idx = 0; idx < upper; ++idx)
    {
      if (((size_t)1 << idx) & mask)
      {
        sum += nums[lower + idx];
      }
    }

    upper_map[mask] = sum;
  }

  sort(lower_map.begin(), lower_map.end());
  sort(upper_map.begin(), upper_map.end());

  // Two pointers
  int64_t ans = 0;
  auto left = lower_map.begin();
  auto right = upper_map.rbegin();

  while (left != lower_map.end() && right != upper_map.rend())
  {
    int sum = *left + *right;
    if (sum == target)
    {
      int64_t left_cnt = 1;
      int64_t right_cnt = 1;

      int left_prev = *left;
      int right_prev = *right;

      ++left;
      ++right;

      while (left != lower_map.end() && *left == left_prev)
      {
        ++left_cnt;
        ++left;
      }

      while (right != upper_map.rend() && *right == right_prev)
      {
        ++right_cnt;
        ++right;
      }

      ans += left_cnt * right_cnt;
    }
    else if (sum < target)
    {
      ++left;
    }
    else
    {
      ++right;
    }
  }

  return ans;
}

int main(void)
{
  // For Faster IO
  ios_base::sync_with_stdio(false);
  cout.tie(nullptr);
  cin.tie(nullptr);

  // Read Input
  int N, S;
  cin >> N >> S;

  NUMS.assign(N, 0);
  for (int it = 0; it < N; ++it)
  {
    cin >> NUMS[it];
  }

  // Solve
  cout << Solve(NUMS, S) - (S == 0 ? 1 : 0) << '\n';

  return 0;
}
profile
Pseudo-worker

0개의 댓글