문제
트리의 지름은 임의의 두 정점 사이의 거리 중 가장 긴 것을 말한다.
트리의 정보가 주어졌을 때 트리의 지름을 구하라
입력
트리가 입력으로 주어진다. 먼저 첫 번째 줄에서는 트리의 정점의 개수 V가 주어지고 (2 ≤ V ≤ 100,000)둘째 줄부터 V개의 줄에 걸쳐 간선의 정보가 다음과 같이 주어진다. 정점 번호는 1부터 V까지 매겨져 있다.
먼저 정점 번호가 주어지고, 이어서 연결된 간선의 정보를 의미하는 정수가 두 개씩 주어지는데, 하나는 정점번호, 다른 하나는 그 정점까지의 거리이다.
출력
첫째 줄에 트리의 지름을 출력한다.
접근
자료구조 트리에 대해 공부하기위해 이 문제를 풀었다.
일단 트리는 그래프와 다르게 순환하지 않는 자료구조고 어떤 노드가 루트인가에 따라 모양이 달라진다.
지금까지 트리의 지름을 구하기위한 방법을 이해한 것은 다음과 같다.
트리의 지름을 구하기 위해서는 두 개의 리프노드를 선택해서 그 사이의 거리를 구한다. 그렇다면 리프노드를 먼저 구해야하는데 루트노드에 따라서 리프노드가 달라진다.
그렇다면 모든 루트노드에 대해 탐색을 해서 거리가 가장 먼 리프노드를 구해야할까?
지금까지 이해한 것으로는 아니라는 것이다.
다른 사람들의 풀이를 보면 임의의 노드를 루트 노드로하여 가장 먼 리프노드를 구하고 해당 리프노드를 다시 루트노드로하여 가장 먼 노드 하나를 구한다는 것이다.
정리한다면 임의의 노드를 시작으로하여 리프노드를 구한다면 트리의 지름을 구성하는 두 개의 정점 중 하나가 될 것이고 이 리프노드를 시작으로 탐색하여 지름을 구한다.
구현
먼저 입력되는 트리의 정보를 통해 트리를 만들어준다.
N = int(input())
tree = [[] for _ in range(N)]
for _ in range(N):
tree_info = list(map(int, input().split()))
for i in range(1, len(tree_info)-1, 2):
tree[tree_info[0]-1].append((tree_info[i]-1, tree_info[i+1]))
그리고 임의의 노드를 시작으로 가장 먼 리프노드 하나를 구하고 이 리프노드를 시작으로 트리의 지름을 구한다.
dist, node = bfs(0)
dist, node = bfs(node)
BFS를 통해 루트 노드부터 탐색을 진행한다.
visited에는 부모노드부터 자식노드까지의 거리를 담는다. 아직 방문하지 않은 노드라면 지금까지의 거리를 추가해주고 최대거리를 찾는다. 최종적으로 최대거리와 최대거리에 있는 노드를 반환한다.
def bfs(start_node):
visited = [-1] * N
visited[start_node] = 0
far_dist_node = [0, 0]
q = deque()
q.append(start_node)
while q:
parent = q.popleft()
for child, dist in tree[parent]:
if visited[child] == -1:
q.append(child)
visited[child] = visited[parent] + dist
if far_dist_node[0] < visited[child]:
far_dist_node[0] = visited[child]
far_dist_node[1] = child
전체 코드
import sys
from _collections import deque
input = sys.stdin.readline
def bfs(start_node):
visited = [-1] * N
visited[start_node] = 0
far_dist_node = [0, 0]
q = deque()
q.append(start_node)
while q:
parent = q.popleft()
for child, dist in tree[parent]:
if visited[child] == -1:
q.append(child)
visited[child] = visited[parent] + dist
if far_dist_node[0] < visited[child]:
far_dist_node[0] = visited[child]
far_dist_node[1] = child
return far_dist_node
N = int(input())
tree = [[] for _ in range(N)]
for _ in range(N):
tree_info = list(map(int, input().split()))
for i in range(1, len(tree_info)-1, 2):
tree[tree_info[0]-1].append((tree_info[i]-1, tree_info[i+1]))
dist, node = bfs(0)
dist, node = bfs(node)
print(dist)
'알고리즘 풀이 > 백준' 카테고리의 다른 글
[백준 1068] 트리 (0) | 2021.04.04 |
---|---|
[백준 1967] 트리의 지름 (0) | 2021.04.04 |
[백준 1477] 휴게소 세우기 (0) | 2021.04.01 |
[백준 9465] 스티커 (0) | 2021.03.30 |
[백준 2110] 공유기 설치 (0) | 2021.03.29 |