백준 #2213 트리의 독립집합


BOJ #2213 - 트리의 독립집합

https://www.acmicpc.net/problem/2213

# 문제 분류

트리 DP

# 풀이 접근 방법


꽤 어려운 문제였다. 트리 DP를 처음 접했을 뿐만 아니라, 문제의 발문 자체도 처음에는 잘 이해되지 않았다.

각 정점을 선택(체크)하거나 선택하지 않는(언체크) 두 가지 경우로 나누어 계산한다는 점에서 이진수 찾기(#2248), 사회망 서비스(#2533)와 비슷한 형태이며, 이 문제는 그 방식을 트리 위에서 적용하는 트리 DP이다.


#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <queue>
#include <utility>
#include <vector>

// Shorthand macros for common std names and pair members.
#define mp make_pair
#define pb push_back
#define X first
#define Y second
// Inclusive counting loops: i goes from a up to (fup) or down to (fdn) b in steps of c.
#define fup(i, a, b, c) for (int(i) = (a); (i) <= (b); (i) += (c))
#define fdn(i, a, b, c) for (int(i) = (a); (i) >= (b); (i) -= (c))
// Makes `endl` expand to a plain newline string instead of std::endl.
#define endl "\n"

using namespace std;

typedef long long ll;
typedef pair<ll, ll> pl;
typedef pair<int, int> pi;
typedef vector<int> vi;
typedef vector<pi> vii;
const ll MOD = 1e9 + 7;
const int stMAX = 1 << 18;
const int INF = 1e9;

// Single capacity shared by every per-vertex array below. The arrays used to be
// declared with two different sizes (100010 for node/adj/tree, 10010 for
// dp/check), so a vertex index that fit the larger arrays could overrun the
// smaller ones. Sizing them all from one constant removes that mismatch.
const int MAXN = 100010;

// N        : number of vertices read from input.
// node[v]  : value added to the total when vertex v is selected.
// adj[v]   : undirected adjacency list of v.
// tree[v]  : children of v once the tree is rooted by dfs().
// answer   : vertices collected by track().
int N, w, node[MAXN];
vi adj[MAXN], tree[MAXN], answer;
// dp[v][s] : memoised result of solve(v, s); main() fills it with -1 (= not computed).
// check[v] : set to 1 when selecting v gives a strictly larger result than skipping it.
int dp[MAXN][2], check[MAXN];

// Roots the undirected graph in `adj` at the initial `curr`: every neighbour
// other than `prev` (the vertex we arrived from) is recorded as a child of
// `curr` in `tree`, then visited recursively.
void dfs(int curr, int prev) {
    for (const int neighbor : adj[curr]) {
        if (neighbor == prev) {
            continue;  // do not walk back up the edge we came down
        }
        tree[curr].push_back(neighbor);
        dfs(neighbor, curr);
    }
}

// Memoised tree DP over the subtree rooted at `curr`.
//   tf == true  : `curr` is selected, so its value node[curr] is counted and
//                 every child must be left unselected.
//   tf == false : `curr` is not selected, so each child contributes the better
//                 of its two states.
// While evaluating the unselected case, check[child] is set to 1 for every
// child whose selected result is strictly larger than its unselected one.
// Results are cached in dp[curr][tf]; -1 marks an entry not yet computed.
int solve(int curr, bool tf) {
    int &memo = dp[curr][tf];
    if (memo != -1) {
        return memo;
    }

    int total = tf ? node[curr] : 0;
    for (const int child : tree[curr]) {
        const int skipped = solve(child, false);
        if (tf) {
            // Parent is selected: the child can only be skipped.
            total += skipped;
            continue;
        }

        const int taken = solve(child, true);
        if (taken > skipped) {
            check[child] = 1;
        }
        total += std::max(skipped, taken);
    }

    memo = total;
    return memo;
}

// Walks the rooted tree and collects the selected vertices into `answer`.
// `tf` says whether `curr` itself is selected. Children of a selected vertex
// are always visited as unselected; children of an unselected vertex follow
// the decision stored in check[].
void track(int curr, bool tf) {
    if (tf) {
        answer.push_back(curr);
    }
    for (const int child : tree[curr]) {
        const bool child_selected = tf ? false : (check[child] != 0);
        track(child, child_selected);
    }
}

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    memset(dp, -1, sizeof(dp));

    cin >> N;

    fup(i, 1, N, 1) {
        cin >> node[i];
    }

    fup(i, 2, N, 1) {
        int a, b;
        cin >> a >> b;

        adj[a].pb(b);
        adj[b].pb(a);
    }

    dfs(1, 1);

    int solve1 = solve(1, 0);
    int solve2 = solve(1, 1);

    if (solve1 < solve2) {
        check[1] = 1;
    }

    cout << max(solve1, solve2) << endl;

    track(1, check[1]);

    sort(answer.begin(), answer.end());
    for (int p : answer) {
        cout << p << " ";
    }
}