개발/컴퓨터과학

최소 스패닝 트리 (MST) + 백준 1197번

센솔 2024. 3. 11. 23:02

최소 스패닝 트리란?

 최소 스패닝 트리(Minimum spanning tree) 는 그래프 이론에서 자주 사용하는 개념이다.

최소 신장트리, MST라고도 부르는데 다 같은 말이다.

 

 

스패닝 트리 / 최소 스패닝 트리

'최소 스패닝 트리'를 살펴보기에 앞서, 먼저 '스패닝 트리'를 이해해보자.

스패닝 트리는 어떤 그래프에서 모든 정점을 포함하지만, 사이클이 발생하지 않는 트리다.

다음 그래프를 보면서 이해해보자.

 

위 그래프는 사이클이 있다. 1번, 2번, 3번 노드가 서로 사이클을 발생시킨다.

사이클을 없애면서 모든 노드를 포함하는 그래프는 다음과 같은 모양이 될 것이다.

 

위에 보이는 그래프들 모두  '스패닝 트리' 다. 사이클이 없고 모든 노드를 포함하기 때문이다.

이때 스패닝 트리의 '가중치'를 물어본다면, 모든 간선 비용의 합을 계산하면 된다.

따라서 1번 그래프의 가중치는 3+2+4= 9. 2번 그래프의 가중치는 1+3+4=8, 3번 그래프의 가중치는 1+2+4=7 된다. 

 

'최소 스패닝 트리' 는 만들 수 있는 모든 스패닝 트리 중 가장 가중치가 작은 트리를 말한다.

따라서 위 두 그래프의 경우 3번 그래프가 가중치가 제일 작으므로 최소 스패닝 트리가 된다.

 

Kruskal 알고리즘 - 아이디어

최소 스패닝 트리는 간선의 비용을 최소가 되게 하는 것이 핵심이다.

따라서 다음과 같은 과정을 거치면 최소 스패닝 트리를 구현할 수 있다.

  • 모든 간선을 비용에 따라 오름차순 정렬한다
  • 최소 비용의 간선부터 선택해 트리에 포함시킨다. 이때, 사이클이 발생하면 포함시키지 않는다.

위 그래프를 예시로 본다면 최소 스패닝 트리가 만들어지는 과정은 다음과 같을 것이다.

 

1단계)

비용이 가장 적은 1-2 를 잇는 간선이 트리에 포함된다.

 

2단계)

 그 다음으로 비용이 가장 적은 2-3을 잇는 간선이 트리에 포함된다. 

 

3단계)

 그 다음으로 비용이 가장 적은 간선은 1-3 을 잇는 간선이지만, 이 간선을 포함시키면 사이클이 발생한다. 

간선을 포함시키지 않고 건너뛴다

 

4단계)

 

마지막 간선까지 포함되며 최소 신장 트리가 완성된다.

 

Kruskal 알고리즘 - 코드 구현

 최소 스패닝 트리를 구현할 때 '사이클이 발생하는지' 를 본다고 했다.

눈으로는 이해가 되는데, 컴퓨터에게 이 그래프가 사이클을 포함하는지 알아내게 하려면 어떻게 해야 할까?

 

이때 쓰는 것이 바로 분리 집합, 'Union-Find' 다.

유니온 파인드는 '두 노드가 같은 집합에 속해있는지' 를 판별한다.

 

두 노드가 같은 집합에 속해있다는 것은 이미 서로가 연결되어 있다는 것을 의미한다.

이미 연결된 두 노드 사이에 또 다른 간선을 추가하는 것은 사이클을 생성하는 행위가 되므로, 이 경우 간선을 포함시키지 않음으로써 사이클의 형성을 방지할 수 있다.

 

Union-find 에 대한 설명은 이 포스팅의 범위를 벗어나는 것 같아서 별도로 다루지는 않겠다.

 

아래는 백준 1197번 문제 '최소 스패닝 트리' 를 풀이한 코드다.

 

1197번: 최소 스패닝 트리

첫째 줄에 정점의 개수 V(1 ≤ V ≤ 10,000)와 간선의 개수 E(1 ≤ E ≤ 100,000)가 주어진다. 다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 주어진다. 이는 A번 정점과 B번 정점이

www.acmicpc.net

 

 

아래 코드에 대한 설명:

첫째 줄에 정점의 개수 V와 간선의 개수 E가 입력되고,
다음 E개의 줄에는 각 간선에 대한 정보를 나타내는 세 정수 A, B, C가 입력된다고 생각했을 때

최소 신장트리의 가중치를 구하는 코드다

 

# 유니온파인드+Kruskal 알고리즘을 이용한 풀이

# 부모 찾기
def find_parent(x):
    if parent[x] != x:
        parent[x] = find_parent(parent[x])
    return parent[x]

# 부모 합치기
def union_parent(a, b):
    a = find_parent(a)
    b = find_parent(b)
    
    if a < b:
        parent[b] = a
    else:
        parent[a] = b
    
V, E = map(int, input().split())
parent = list(range(0, V+1, 1))
edges = []

for i in range(E):
    A, B, C = map(int, input().split())
    edges.append((C, A, B))

edges = sorted(edges, key = lambda x: x[0])
total_cost = 0

for i in range(len(edges)):
    cost, a, b = edges[i]
    
    if find_parent(a) != find_parent(b): # 부모가 다르면 사이클 발생 x, 신장트리에 포함시키기
        union_parent(a, b)
        total_cost += cost

print(total_cost)