Algorithm

[백준] 중앙 트리 7812번 (java) 트리에서 dp

salmon16 2025. 2. 23. 15:40

출처 : https://www.acmicpc.net/problem/7812

 

풀이 방법

처음엔 문제의 케이스가 여러 번 주어진다는 것을 보지 못하고 O(n^2)으로 풀이할 수 있어서 풀이했지만, 케이스가 여러 개 주어지므로 시간 초과가 발생한다.

 

즉 O(n)만에 풀이를 해야한다.

 

항상 트리 문제를 풀이할 땐, 문제가 풀리지 않는다면 노드를 기준이 아닌 엣지를 기준으로 생각해 보아야 한다.

A가 중앙일 때와 B가 중앙일 때를 비교해 보면 A와 B를 연결하는 엣지의 더해지는 횟수만 변경되고 나머지 엣지는 유지된다.

이 점에 초점을 맞추어서 고민을 해보았다.

 

정답은 서브트리와 관련이 있었다.

A에서 B로 이동할 때 A-B엣지가 더해지는 수는 (B 쪽 엣지를 끊고 난 후 A의 자식 수) - (A쪽 엣지를 끊고 난 후 B의 자식 수) 만큼 더해지게 된다.

 

이것을 구현하기 위해 고민해봐야 할 것은 반대쪽 노드와 연결된 엣지를 끊고 난 후의 자식의 개수를 어떻게 구할 것이냐가 문제이다.

이는 기준점을 잡으면 해결할 수 있다. 즉 A를 루트로 생각한 후 dfs를 수행하며 자식의 노드 수만 저장을 한다면 (A 쪽 엣지를 끊고 난 후 B의 자식 수)는 A를 기준으로 dfs를 탐색했으므로 구해진 B의 자식 수이고, (B 쪽 엣지를 끊고 난 후 A의 자식 수)는 n - B의 자식수를 해주면 쉽게 구할 수 있다.

 

이를 통해 점화식을 dp[next] = dp[cur] + childCnt[next] * edgeWeight - edgeWeight * (n - childCnt[next])로 구할 수 있다.

이를 구할 때도 A를 기준으로 dfs를 수행해야 한다.

 

import java.io.*;
import java.util.*;

public class P_7812 {

    static int n;
    static ArrayList<ArrayList<Edge>> edges;
    static long[] dp;
    static int[] childCnt = new int[10001];
    static boolean[] visited = new boolean[10001];
    static class Edge{
        int idx, weight;
        Edge(int idx, int weight){
            this.idx = idx;
            this.weight = weight;
        }
    }

    public static int dfsCnt(int idx) {
        // 자신을 포함함
        visited[idx] = true;
        int ret = 1; // 자신의 수

        for (int i = 0;i < edges.get(idx).size();i++) {
            int next = edges.get(idx).get(i).idx;
            if (visited[next]) continue;
            int cnt = dfsCnt(next);
            ret +=  cnt;
            dp[idx] += ((long) edges.get(idx).get(i).weight * cnt) + dp[next];
        }
        childCnt[idx] = ret;
        return ret;
    }

    public static void dfs(int idx) { // 0을 기준으로 한 weight를 모두 구함
        visited[idx] = true;

        for (int i = 0;i < edges.get(idx).size();i++) {
            int next = edges.get(idx).get(i).idx;
            int weight = edges.get(idx).get(i).weight;
            if (visited[next]) continue;
            dp[next] = dp[idx] - (long) weight * childCnt[next] + (long) weight * (n - childCnt[next]);
            dfs(next);
        }
        return ;
    }
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringTokenizer st;
        while(true) {
            n = Integer.parseInt(br.readLine());
            if (n == 0) break;
            dp = new long[n];
            int a, b, c;
            edges = new ArrayList<>();
            for (int i = 0; i < n; i++) {
                edges.add(new ArrayList<>());
                Arrays.fill(dp, 0);
            }
            for (int i = 0; i < n - 1; i++) { // 간선 양방향으로 입력 받기
                st = new StringTokenizer(br.readLine());
                a = Integer.parseInt(st.nextToken());
                b = Integer.parseInt(st.nextToken());
                c = Integer.parseInt(st.nextToken());
                edges.get(a).add(new Edge(b, c));
                edges.get(b).add(new Edge(a, c));
            }

            Arrays.fill(visited, false);
            dfsCnt(0);

            Arrays.fill(visited, false);
            dfs(0);
            long ans = Long.MAX_VALUE;
            for (int i = 0;i < n;i++) {
                ans = Math.min(ans, dp[i]);
            }
            bw.write(Long.toString(ans) + '\n');

        }
        bw.flush();
        bw.close();
    }

}