본문 바로가기

알고리즘&자료구조

[백준1208] 부분수열의 개수2

어떤 집합의 부분집합 중 그 원소의 합이 S인 부분집합의 개수를 구하는 문제이다.

문제상에는 수열, 부분수열이라 나와있지만 명칭의 직관성을 위해 집합, 부분집합이라고 하겠음

Idea1. Brute Force

#include <iostream>

using namespace std;

int n, s;
int arr[40];

int solution(int curr, int start, int end) {
    
    if (start > end) {
        if (curr == s) return 1;
        else return 0;
    } else {
        int ret = 0;
        ret += solution(curr + arr[start], start+1, end, isFront);
        ret += solution(curr, start+1, end, isFront);
        return ret;
    }
}

위 코드는 단순하게 모든 부분집합을 재귀적으로 조회하면서 원소의 합이 S인 경우를 추적한다. 원소가 n개 일때 부분집합의 개수는 2^n이므로, 모든 부분집합을 조회하는 것만으로도 **O(2^n)**의 시간복잡도가 발생한다. n이 최대 40이므로 2^40 = 약 10^12 이므로 1초의 제한에서는 시간초과가 발생한다.

 

Idea2. 나눠서 생각하기

 

Step1. 집합을 둘로 나눠 부분집합 조회

위 알고리즘에서 가장 큰 문제는 40개의 원소를 대상으로 모든 부분집합을 조회하는 것이다. 집합의 원소 개수를 줄인다면 충분히 모든 부분집합을 구할 수 있을것이다. 즉, 문제를 둘로 나누면 된다.

 

40개의 원소를 20개, 20개 나눠 각각 모든 부분집합을 조회한다면 각각 2^20의 부분집합이 존재하는 이는 약 10^6정도이므로 1초내에 계산하기 충분하다.

 

Step2. 두 집합의 부분집합을 합쳐 합이 S인 부분집합의 개수 구하기

이제 두개의 집합이 각각 모든 부분집합을 조회해 부분집합의 합을 구했다고 하자. 이제 진짜 답을 구해야한다. 두 집합의 부분집합을 합쳤을때 합이 S여야한다.

 

예를 들어 집합 P를 A와 B 둘로 나눴다고 해보자 그리고 A와 B의 부분집합을 각각 a, b라고 하자. 만약 어떤 부분집합 a의 합이 k라면 a와 결합해 합이 S인 부분집합의 개수는 합이 S-k인 모든 부분집합 b의 개수일 것이다. 그리고 이것을 모든 a에 대해 조회해 그 값을 모두 더하면 된다.

 

Step3. 알맞은 자료구조 정하기

step1. 에서 둘로 나눠진 집합의 부분집합들을 모두 조회해 그 합을 미리 구해놔야한다. 그렇다면 이값을 어디에다 저장해야 할까? 우리가 구하고 싶은것을 부분집합의 개수이다. 즉, step2에서 문제를 풀기 위해서는 어떤 합 k를 가진 모든 부분집합의 개수를 알아야 한다. 즉, key가 원소의 합이고 value가 그 합을 가지는 부분집합의 개수은 Map 자료형을 사용하는 것이 합리적이다.

정리

자 이제 과정을 정리해보자

 

  1. 집합을 둘로 나눠(A, B) 각각 모든 부분집합을 조회
  2. 부분집합의 합을 구하여 Map(원소의 합, 부분집합의 개수) 형태로 데이터를 저장
  3. A의 부분집합중 합이 k인 부분집합의 개수와 B의 부분집합 중 합이 S-k인 부분집합의 개수를 곱합. 이를 가능한 모든합계에 대해 진행

 

Solution

#include <iostream>
#include <unordered_map>

using namespace std;

using ll = long long;

int arr[40];
// start ~ mid
unordered_map<int, ll> m1;
// mid+1 ~ end
unordered_map<int, ll> m2;

int min_s = 40000001;
int max_s = -40000001;

void solution(int curr, int start, int end, bool isFront) {
    
    if (start > end) {
        if (max_s < curr) max_s = curr;
        if (min_s > curr) min_s = curr;
        if (isFront) {
            if (m1.find(curr) == m1.end()) {
                m1.insert({curr, 0});
            }
            m1[curr]++;
        } else {
            if (m2.find(curr) == m2.end()) {
                m2.insert({curr,0});
            }
            m2[curr]++;
        }
    } else {
        solution(curr + arr[start], start+1, end, isFront);
        solution(curr, start+1, end, isFront);
    }
}

int main() {
    
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    cout.tie(NULL);
    
    int n, s;
    cin >> n >> s;
    
    for (int i=0; i<n; ++i) {
        cin >> arr[i];
        // arr[i] = 0;
    }
    
    int mid = (n-1) / 2;
    solution(0, 0, mid, true);
    solution(0, mid+1, n-1, false);
    
    ll ret = 0;
    
    for (int i=min_s; i<=max_s; ++i) {
        int front = i;
        int back = s - i;
        
        if (m1.find(front) == m1.end()) continue;
        if (m2.find(back) == m2.end()) continue;
        
        ret += m1[front] * m2[back];
    }
    
    if (s == 0) ret--;
    cout << ret;
    
    
}
  • solution은 부분집합의 원소합을 구해 map에 저장한다.
    • isFront는 현재 구하는 부분집합의 합이 어떤 집합의 부분집합인지를 판단한다.
      • True : 앞쪽 부분집합(0 ~ mid) / False : 뒷쪽 부분집합 (mid+1 ~ end)
    • start가 end보다 크면 부분집합 조회가 끝난 것이므로 부분집합의 원소 합을 구해 해당 합계에 대응하는 value를 1증가
  • solution 함수를 통해 map의 값을 모두 채웠다면, 가능한 모든 원소합계(min_s ~ max_s)까지 순회하면 (앞의 집합에서 원소합이 k인 부분집합의 개수) * (뒤의 집합에서 원소합이 S-k인 부분집합의 개수)를 구해서 모두 더한다.