알고리즘 풀이/백준

[백준 1197] 최소 스패닝 트리

mhko411 2021. 4. 21. 20:16
728x90

문제

그래프가 주어졌을 때 그 그래프의 최소 스패팅 트리를 구하라.

 

입력

첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이 가중치 C인 간선으로 연결되어 있다는 의미이다. C는 음수일 수도 있으며, 절댓값이 1,000,000을 넘지 않는다.

그래프의 정점은 1번부터 V번까지 번호가 매겨져 있고, 임의의 두 정점 사이에 경로가 있다. 최소 스패닝 트리의 가중치가 -2,147,483,648보다 크거나 같고, 2,147,483,647보다 작거나 같은 데이터만 입력으로 주어진다.

 

출력

첫째 줄에 최소 스패닝 트리의 가중치를 출력한다.


접근

최소 스패닝 트리를 공부하면서 최소 스패닝 트리를 찾는 알고리즘 중 크루스칼과 프림을 공부하기위해 이 문제를 풀었다.

 

먼저 스패닝 트리는 모든 노드가 연결된 그래프에서 싸이클이 없는 부분 그래프를 의미한다. 어떠한 그래프 내에서 다양한 스패닝 트리가 생성될 수 있으며, 스패닝 트리는 싸이클 구조를 갖지않고 (노드의 개수 - 1)개의 간선으로 이루어져있다.

 

여기서 최소 스패닝 트리는 스패닝 트리 중에 간선의 가중치 합이 가장 적은 스패닝 트리를 의미한다.

이 문제에서는 간선의 가중치 합이 가장 적은 스패닝 트리를 찾기위해 크루스칼 알고리즘과 프림 알고리즘을 적용시켰다.

 

구현

크루스칼 알고리즘

크루스칼 알고리즘은 A노드에서 B노드로 가는 무방향 간선이 C라는 비용을 갖는다고 할 때 비용을 오름차순으로 정렬하여 가장 적은 비용의 경로부터 선택하는 것이다. 

스패닝 트리는 노드의 개수보다 한 개 적은 간선을 갖고있기 때문에 해당 개수만큼의 경로를 선택하면 된다.

하지만 스패닝 트리는 싸이클이 없기 때문에 현재 어떠한 경로를 선택할 때 싸이클이 되는지 판단을 해야한다.

이때 사용하는 것이 각 노드들의 최상위 노드가 일치하는지 판단을 한다.

 

아래의 코드는 각 노드들의 최상위 노드를 찾는 과정을 나타낸다.

먼저 처음에 각 노드는 자기자신을 가리키게된다. 이어서 A에서 B의 경로를 선택하면 B의 부모노드가 A가 되도록 union과정을 거친다. 두 개의 부모노드가 다르다면 B의 부모노드를 A로 결정한다.

또한 노드들의 부모노드를 찾기위해 find_set()과정을 거친다. 만약 자기자신을 가리킨다면 그대로 반환을 하면되지만,

그렇지 않을 경우 입력받은 A노드의 최상위 노드를 찾아 A에 저장한다.

이 과정을 통해 싸이클 구조를 갖는지 판단할 수 있다.

parent = [n for n in range(V+1)]

def find_set(a):
    if parent[a] == a:
        return a
    else:
        b = find_set(parent[a])
        parent[a] = b
        return b
        
def union(a, b):
    a = find_set(a)
    b = find_set(b)
    if a != b:
        parent[b] = a

 

이제 그래프에 대한 정보를 입력받는다. 2차원 리스트로 그래프를 표현할 필요없이 A에서 B로 C만큼의 비용이 든다는 정보만 저장해둔다. 

이후 비용을 기준으로 오름차순으로 정렬한다.

V, E = map(int, input().split())
graph = []
for _ in range(E):
    a, b, c = map(int, input().split())
    graph.append((a, b, c))

graph = sorted(graph, key=lambda x:x[2])

 

이제 V-1개의 간선을 비용이 적은 경우부터 탐색을 한다.

만약 A와 B의 최상위 노드가 다르면 싸이클이 되지않기 때문에 선택을 하고 B의 최상위 노드를 A로 설정한다.

이 과정에서 선택한 간선의 개수를 카운트하여 V-1이 될 때 탐색을 종료하고 지금까지의 더한 비용을 출력한다.

answer = 0
edge_count = 0
for g in graph:
    a = g[0]
    b = g[1]
    d = g[2]
    if find_set(a) != find_set(b):
        union(a, b)
        answer += d
        edge_count += 1
    if edge_count == V-1:
        break
print(answer)

 

전체 코드

import sys
input = sys.stdin.readline

def find_set(a):
    if parent[a] == a:
        return a
    else:
        b = find_set(parent[a])
        parent[a] = b
        return b

def union(a, b):
    a = find_set(a)
    b = find_set(b)
    if a != b:
        parent[b] = a

V, E = map(int, input().split())
graph = []
for _ in range(E):
    a, b, c = map(int, input().split())
    graph.append((a, b, c))

parent = [n for n in range(V+1)]

graph = sorted(graph, key=lambda x:x[2])
answer = 0
edge_count = 0
for g in graph:
    a = g[0]
    b = g[1]
    d = g[2]
    if find_set(a) != find_set(b):
        union(a, b)
        answer += d
        edge_count += 1
    if edge_count == V-1:
        break
print(answer)

프림 알고리즘

프림 알고리즘은 특정 노드부터 시작해서 인접 노드 중에서 가장 적은 비용의 노드를 선택해나간다.

코드를 통해 살펴보자.

 

인접리스트로 그래프로 만든다.

각 노드의 인접한 노드와 비용을 함께 저장한다.

V, E = map(int, input().split())
graph = [[] for _ in range(V+1)]

for _ in range(E):
    a, b, c = map(int, input().split())
    graph[a].append((b, c))
    graph[b].append((a, c))

 

우선순위 큐를 활용해서 가정 적은 비용을 먼저 빼올 수 있도록한다.

또한 방문표시를 해서 두 번 이상 방문하지않아 싸이클 구조를 갖지않도록 한다.

우선순위 큐에서 데이터를 빼와서 인접한 노드 중에서 가장 적은 비용을 갖는 노드를 꺼낸다. 이제 이 노드에서 인접한 노드를 탐색하여 방문하지않은 노드를 다시 우선순위 큐에 넣고 위의 과정을 반복한다.

 

프림 알고리즘은 좀 더 다뤄봐야겠다.

def prim():
    result = 0
    visited = [False] * (V+1)
    hq = []
    heapq.heappush(hq, (0, 1)) # 비용 - 노드번호
    while hq:
        # 0: 비용, 1:노드번호
        # 아직 선택하지않은 인접노드 중 비용이 가장 적은 노드를 선택
        cur = heapq.heappop(hq)

        if visited[cur[1]]:
            continue
        visited[cur[1]] = True

        result += cur[0]
        for g in graph[cur[1]]:
            if not visited[g[0]]:
                heapq.heappush(hq, (g[1], g[0]))

전체 코드

import sys
import heapq
from _collections import deque
input = sys.stdin.readline

def prim():
    result = 0
    visited = [False] * (V+1)
    hq = []
    heapq.heappush(hq, (0, 1))
    while hq:
        cur = heapq.heappop(hq)

        if visited[cur[1]]:
            continue
        visited[cur[1]] = True

        result += cur[0]
        for g in graph[cur[1]]:
            if not visited[g[0]]:
                heapq.heappush(hq, (g[1], g[0]))

    return result

V, E = map(int, input().split())
graph = [[] for _ in range(V+1)]

for _ in range(E):
    a, b, c = map(int, input().split())
    graph[a].append((b, c))
    graph[b].append((a, c))

answer = prim()
print(answer)

 

'알고리즘 풀이 > 백준' 카테고리의 다른 글

[백준 1717] 집합의 표현  (0) 2021.04.22
[백준 1753] 최단경로  (0) 2021.04.22
[백준 2583] 영역 구하기  (0) 2021.04.21
[백준 1987] 알파벳  (0) 2021.04.20
[백준 15658] 연산자 끼워넣기(2)  (0) 2021.04.19