Algorithm

[백준] 통신망 분할 17398번 (c++) 순서 역으로 생각하기

salmon16 2025. 3. 2. 19:34

출처 : https://www.acmicpc.net/problem/17398

 

풀이 방법

풀이의 방법만 안다면 구현은 간단한 문제였다.

 

간선을 하나씩 끊어 가면서 이때 나누어진 두 집합의 노드의 수를 계산하는 문제이다.

만약 간선을 끊을 때마다, 끊어진 두 노드에서 bfs 또는 dfs를 수행해서 각 집합의 노드의 수를 계산하게 된다면 시간 복잡도가 초과된다.

그렇게 때문에 nlogn 이하의 시간 복잡도를 가지는 알고리즘을 생각해 내야 했다.

 

지금 까지 공부하며 시간 복잡도를 줄였던 방법에 대해 생각해 보았다. 

1. 노드기준이 아닌 엣지 기준으로 생각해 보기

2. 방향 간선에선 간선의 방향을 바꾸어 보기

3. 순서를 역으로 생각해 보기

 

등등 생각이 났지만, 이 문제에 적용할 수 있는 방법은 역으로 생각하기였다.

 

1. 먼저, 끊을 간선을 제외한 상태로 그래프를 만든다.

2. 이제 끊어진 간선을 하나씩 다시 연결하면서 노드 개수를 계산한다.

3. 간선을 추가하기 전의 두 집합 크기가, 해당 간선을 끊었을 때의 두 집합 크기가 된다.

 

이 방법을 쓰면 매번 BFS/DFS를 돌리지 않아도 각 집합의 크기를 빠르게 구할 수 있다.

 

#include <iostream>
#include <vector>

using namespace std;
vector<pair<int, int> > edge;
long long cnt[100001];
int visited[100001];
int parents[100001];
vector<int> orders;

int n, m, q;

int find(int a) {
    if (parents[a] == a) return a;

    parents[a] = find(parents[a]);
    return parents[a];
}

void union_node(int a, int b){
    int parents_a = find(a);
    int parents_b = find(b);

    if (parents_a == parents_b) return;
    if (parents_a < parents_b) {
        parents[parents_b] = parents_a;
        cnt[parents_a] += cnt[parents_b];
        cnt[parents_b] = 0;
    }

    else {
        parents[parents_a] = parents_b;
        cnt[parents_b] += cnt[parents_a];
        cnt[parents_a] = 0;
    }
}

int main() {

    cin >> n >> m >> q;    

    int a, b;
    for (int i = 0;i < n;i++) {
        parents[i] = i;
        cnt[i] = 1;
    }
    for (int i = 0;i < m;i++) {
        cin >> a >> b;
        edge.push_back(make_pair(a-1, b-1));
    }    
    for (int i = 0;i < q;i++) {
        cin >> a;
        visited[a-1] = 1; // 끊을 엣지
        orders.push_back(a-1);
    }

    for (int i = 0;i < m;i++) { //초기 다 묶기
        if (visited[i] == 1) continue;
        union_node(edge[i].first, edge[i].second);
    }
    long long ans = 0;    
    for (int i = q-1; i >= 0 ;i--) {
        int a = edge[orders[i]].first;
        int b = edge[orders[i]].second;        
        int parents_a = find(a);
        int parents_b = find(b);
        if (parents_a == parents_b) continue;
        ans += (cnt[parents_a] * cnt[parents_b]);
        union_node(a, b);        
    }

    cout << ans << endl;

    return 0;
}