DynamicCube.exe

12 object(s)
 

BOJ 2213 : 트리의 독립집합

https://www.acmicpc.net/problem/2213

이 문제는 Tree DP로 풀 수 있다.

DP를 정의하면

dp[i][b] = 루트를 i로 하는 서브트리에서 루트를 포함하는 여부(b)에 따른 최적해(트리의 최대 독립집합임).
  
b 는 0,1이고, 0이면 포함하지 않고, 1이면 포함한다.

이때 dp를 두 가지 경우로 나누어 풀 수 있다.

b = 1 :

  자식노드가 항상 b = 0 이여야 하므로 Σ(dp[(자식노드)][0])이 최적해이다.

b = 0 :

  각 자식노드가 b = 0 또는 b = 1 이므로, Σ(max(dp[(자식노드)][0], dp[(자식노드)][1]))이 최적해이다.

문제에 맞게 트리를 구축한 다음, dfs를 돌면서 dp값을 갱신해 준다.

이때, 트리의 루트에서 b가 0일때와 1일때를 같이 구해놓는다. 왜냐하면 루트에서 b가 0일때와 1일때에 따라 최적해가 결정되기 때문이다.

그중 최댓값이 해이다. 남은 문제는 독립집합을 출력하는 일만 남았다.

간단하게 b의 값을 각 노드마다 저장해주는 배열을 생성한 뒤, dp를 수행하면서 배열에 b값을 저장한다.

그리고 dfs를 통해 배열에서의 b값이 1이면 답을 인접 리스트(stl vector), 큐 등의 자료구조에 넣는다.

그리고 답을 출력해주면 남은 문제도 풀리게 된다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#include <bits/stdc++.h>
#define MEM 100005
#define sanic ios_base::sync_with_stdio(0)
using namespace std;
int n;
int dp[MEM][2];
int cost[MEM];
int a[MEM];
vector<int> g[MEM];
vector<int> t[MEM];
vector<int> ans;
void dfs(int c, int p){
    for(int i=0; i<g[c].size(); i++){
        int nx = g[c][i];
        if(nx!=p){
            t[c].push_back(nx);
            dfs(nx, c);
        }
    }
}
int DP(int cur, int b){
    int& r=dp[cur][b];
    if(r!=-1return r;
    r = 0;
    if(b){
        for(int i=0; i<t[cur].size(); i++){
            int nx=t[cur][i];
            r += DP(nx, 0);
        }
        r+=cost[cur];
        return r;
    }
    else{
        for(int i=0; i<t[cur].size(); i++){
            int nx=t[cur][i];
            int r1 = DP(nx, 0);
            int r2 = DP(nx, 1);
            if(r1>r2) a[nx] = 0;
            else a[nx] = 1;
            r += max(r1, r2);
        }
        return r;
    }
}
void findd(int cur, int b){
    if(b){
        ans.push_back(cur);
        for(int i=0; i<t[cur].size(); i++){
            int nx=t[cur][i];
            findd(nx, 0);
        }
    }
    else{
        for(int i=0; i<t[cur].size(); i++){
            int nx=t[cur][i];
            findd(nx, a[nx]);
        }
    }
}
main()
{
    sanic;
    cin >> n;
    for(int i=1; i<=n; i++)
        cin >> cost[i];
    for(int i=1; i<n; i++){
        int q1, q2;
        cin >>  q1 >> q2;
        g[q1].push_back(q2);
        g[q2].push_back(q1);
    }
    dfs(10);
    memset(dp, -1sizeof(dp));
    int a1=DP(1,0);
    int a2=DP(1,1);
    if(a1<a2) a[1= 1;
    else a[1= 0;
    cout << max(a1, a2) << '\n';
    findd(1, a[1]);
    sort(ans.begin(), ans.end());
    for(auto l : ans)
        cout << l << ' ';
}
 
cs