나만의 작은 도서관

문제 1967. 트리의 지름(에센셜 4) 본문

백준 문제풀이/Graph Theory

문제 1967. 트리의 지름(에센셜 4)

pledge24 2024. 5. 6. 21:54

문제 링크

https://www.acmicpc.net/problem/1967

 

난이도 : 골드 4

 

문제 요약 설명

노드 개수가 n개인 트리가 하나 있다. 각 노드는 1부터 n까지 번호가 매겨져 있을 때, 트리의 지름을 구하시오. 지름은 노드 사이의 최대 거리를 의미한다.

입력

첫 번째 줄: 노드의 개수

둘째 줄부터 n-1개의 줄: 부모 노드의 번호, 자식 노드의 번호, 간선의 가중치

 

입력 제한

노드의 개수 n(1 ≤ n ≤ 10,000)

0 < 간선의 가중치 < 100 

입력 예제

// input
12
1 2 3
1 3 2
2 4 5
3 5 11
3 6 9
4 7 1
4 8 7
5 9 15
5 10 4
6 11 6
6 12 10

// ans
45

 

풀이 방식

 

임의의 노드를 루트로 하는 트리에서 노드 사이의 최대 거리를 M이라 했을 때, 트리의 지름은 가장 큰 M을 의미한다. 그렇기 때문에 모든 노드에 대해 루트로 지정하고 M을 구하는 작업을 해 가장 큰 M을 찾으면 된다. 문제는 해당 작업을 각 노드마다 따로하면 연산량이 비효율적으로 많다는 문제가 생긴다. 이를 해결하기 위해 dfs를 이용해 단 1번의 완전 탐색으로 모든 M을 구한다. 

 

정답 코드 

더보기
#include <bits/stdc++.h>

#define fastio cin.tie(0)->sync_with_stdio(0)

using namespace std;

struct weight_data
{
    int node_no;
    int weight;
};

vector<vector<weight_data>> nodes;

int N; 
int max_length = -1;

// 현재 노드를 루트로 하는 서브트리의 지름을 구해 가장 긴 지름과 비교 및 저장 후,
// 현재 노드에서 가장 긴 길이를 반환한다.
int dfs(int node_no){
  
    if(nodes[node_no].size() == 0) return 0;

    vector<int> length;
    
    for(weight_data next_node : nodes[node_no]){
        length.push_back(dfs(next_node.node_no) + next_node.weight);
    }  
    
    // 내림차순 정렬.
    sort(length.begin(), length.end(), greater<int>()); 

    if(length.size() >= 2){     
        max_length = max(max_length, length[0]+length[1]);
    }
   
    return length[0];
}

int main() {
	fastio;
    cin >> N;

    nodes.resize(N+1);
   
    // input data.
    int node1, node2, weight;
    for(int i = 0; i < N-1; i++){
        cin >> node1 >> node2 >> weight;
        nodes[node1].push_back({node2, weight});
    }
   
    // 노드 개수가 1인경우, 정답은 0.
    if(N == 1){
        cout << 0 << '\n';
        return 0;
    }

    vector<int> length;

    // 각각의 자식노드 서브트리에서 가장 긴 길이를 가져온다.
    // dfs를 통해 자식노드에서도 반복한다.
    for(weight_data next_node : nodes[1]){
        length.push_back(dfs(next_node.node_no) + next_node.weight);
    }

    // 내림차순 정렬.
    sort(length.begin(), length.end(), greater<int>());     

    // 지름의 길이가 저장된 가장 긴 길이랑 비교해 더 긴 길이를 저장한다.
    // 가져온 길이 중 1번째, 2번째로 긴 길이의 합을 지름의 길이로 결정한다.
    // 자식노드가 1개인 경우, 1번째 길이를 지름의 길이로 결정한다.
    if(length.size() >= 2){
        max_length = max(max_length, length[0]+length[1]);
    }
    else{
        max_length = max(max_length, length[0]);
    }


    cout << max_length << '\n';
}

 

더 좋은 코드?