본문 바로가기

Algorithm/백준

[파이썬, python] 백준 1967 - 트리의 지름

반응형

트리의 지름 문제

 

dp를 사용해서 문제를 풀었는데 다른 사람들은 bfs를 사용해 풀어 정리하게 되었다. 트리의 지름에 대한 개념이 있어야 풀 수 있는 문제인 것 같다.

 

핵심 아이디어: 트리에선, 임의의 노드에서 가장 먼 점이 가장 거리가 먼 두 점(지름의 양 끝) 중 하나의 점

 

많은 사람들이 푼 방식은 bfs를 두번 사용해주는 것이었다.

1. bfs를 통해 임의의 노드로부터 가장 먼 노드A를 구해준다. => 이때의 노드 A는 트리 지름의 한 점이 된다.

2. bfs를 통해 노드A로부터 가장 먼 노드B를 구해준다. => 이때의 노드 B는 트리의 나머지 한 점이 된다.

 

중요한 아이디어는 1번의 " 임의의 노드에서 가장 먼 점이 가장 거리가 먼 두 점 중 하나의 점"이라는 것이다. 

이에 대해 귀류법으로 증명한 블로그가 있어 첨부한다.

https://blog.myungwoo.kr/112

 

트리의 지름 구하기

트리에서 지름이란, 가장 먼 두 정점 사이의 거리 혹은 가장 먼 두 정점을 연결하는 경로를 의미한다. 선형 시간안에 트리에서 지름을 구하는 방법은 다음과 같다: 1. 트리에서 임의의 정점 $x$를

blog.myungwoo.kr

이 아이디어를 가지고 코드를 짜면 아래와 같다.

# BFS
from collections import deque
max_node = -1
def bfs(start_node):
    global n, max_node
    visited =[False]*(n+1)
    que = deque([[start_node,0]])
    visited[start_node]=True
    max_dist = 0

    while que:
        now, now_dist = que.popleft()
        for child,child_dist in data[now]:
            if not visited[child]:
                visited[child]=True
                que.append([child,now_dist+child_dist])
                if max_dist < now_dist+child_dist:
                    max_dist = now_dist+child_dist
                    max_node = child
    return max_dist
                
n = int(input())
if n == 1:
    print(0)
else:
    data = [[] for _ in range(n+1)]
    for _ in range(n-1):
        # 부모, 자식, 거리
        a,b,c = map(int,input().split())
        data[a].append([b,c])
        data[b].append([a,c])
    bfs(1)
    print(bfs(max_node))

 

cf)

내가 푼 방식은, bfs를 통해 각 노드의 level(높이)를 구해주고, 가장 높은 노드( 말단 노드)부터 최대 거리를 저장하는 dp 방식을 사용하였다. 이 방법은 O(n)< 그 사이 어딘가 < O(n^2)이다.

import sys
from collections import deque
input = sys.stdin.readline
n = int(input())
if n == 1:
    print(0)
else:
    data = [[] for _ in range(n+1)]
    for _ in range(n-1):
        # 부모, 자식, 거리
        a,b,c = map(int,input().split())
        data[a].append([b,c])
    level = []
    que = deque([[1,1]])
    # node , level
    level.append([1,1])
    # 각 노드의 level을 구해줌. (BFS)
    while que:
        node,now_level =que.popleft()
        for child,_ in data[node]:
            level.append([child,now_level+1])
            que.append([child,now_level+1])
    level.sort(key=lambda x:-x[1])
    # print(max_level)
    # 각 노드에 대한 현 노드 ~ 말단 노드(leaf) 까지의 최대 거리 구함.
    # max_length 구함.(지름)
    dp = [0]*(n+1)
    max_length = 0
    for node,node_level in level:
        temp = []
        if len(data[node])==0:
            continue
        for child,child_length in data[node]:
            temp.append(dp[child]+child_length)
            # 만약 자식 노드가 두개 이상이라면 가장 긴 두 지름을 구해서 더해보고 최대 지름을 구한다.
            if len(temp)>=2:
                temp.sort(reverse=True)
                a,b=temp[0],temp[1]
                max_length = max(max_length,a+b)
        dp[node]=temp[0]
        # 만약 자식 노드가 한 줄기라면, 현재 최대 지름과 말단 ~ 현 노드의 길이를 비교한다.
        if len(temp)==1:
            max_length = max(max_length,dp[node])
    # if max_length == 0:
    #     print(dp[1])
    print(max_length)
반응형