
Minimum-Cost Connected Graph - Minimum Spanning Tree (MST)

by Marco Backman 2025. 8. 19.

When the edges between nodes carry costs, a minimum spanning tree is the way to connect every node using the set of edges with the lowest total cost.
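To make the definition concrete, here is a minimal brute-force sketch: it tries every set of V - 1 edges, keeps the ones that connect all nodes, and reports the cheapest total. The helper names `connects_all` and `brute_force_mst_cost` are made up for this illustration, nodes are assumed to be labeled 0 to V - 1 with V >= 1, and the approach is only practical for tiny graphs.

```python
from itertools import combinations


def connects_all(num_nodes, edges):
    """Return True if `edges` link nodes 0..num_nodes-1 into one component."""
    adjacency = {node: [] for node in range(num_nodes)}
    for u, v, _ in edges:
        adjacency[u].append(v)
        adjacency[v].append(u)

    # Depth-first search from node 0 over the chosen edges only.
    seen = {0}
    stack = [0]
    while stack:
        node = stack.pop()
        for neighbor in adjacency[node]:
            if neighbor not in seen:
                seen.add(neighbor)
                stack.append(neighbor)
    return len(seen) == num_nodes


def brute_force_mst_cost(num_nodes, edges):
    """Try every set of num_nodes - 1 edges; return the cheapest total, or None."""
    best = None
    for subset in combinations(edges, num_nodes - 1):
        if connects_all(num_nodes, subset):
            cost = sum(weight for _, _, weight in subset)
            if best is None or cost < best:
                best = cost
    return best


# Each edge is (node_a, node_b, cost).
example_edges = [(0, 1, 10), (0, 2, 6), (0, 3, 5), (1, 3, 15), (2, 3, 4), (3, 4, 2)]
print(brute_force_mst_cost(5, example_edges))  # prints 21
```

The number of edge subsets grows very quickly with the size of the graph, which is why greedy algorithms are used in practice.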

 

The following two algorithms are the ones most commonly used to find an MST.

 

Kruskal's Algorithm (edge-based greedy approach) - (edges are chosen in order of cost)

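It always picks the cheapest edge remaining anywhere in the graph.

Given the weight of each edge between nodes, pick the lowest-weight edge first and connect its two nodes. An edge is skipped if adding it would form a cycle with the edges already chosen.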

 

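Keep picking edges from cheapest to most expensive until all nodes are connected, which takes exactly V - 1 edges for V nodes. The sum of the chosen edges' costs is the cost of the MST.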
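The selection order described above can be traced on a small graph. This is only a sketch: `kruskal_trace` is a hypothetical helper written for this illustration, and it uses a bare parent array with path halving instead of a full union-by-rank structure.

```python
def kruskal_trace(num_nodes, edges):
    """Return (total_cost, log); log holds (u, v, weight, accepted) per edge."""
    parent = list(range(num_nodes))

    def find(x):
        # Walk up to the root, halving the path along the way.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total_cost = 0
    log = []
    # Visit edges from cheapest to most expensive.
    for u, v, weight in sorted(edges, key=lambda edge: edge[2]):
        root_u, root_v = find(u), find(v)
        accepted = root_u != root_v  # equal roots would mean a cycle
        if accepted:
            parent[root_u] = root_v
            total_cost += weight
        log.append((u, v, weight, accepted))
    return total_cost, log


# Each edge is (node_a, node_b, cost).
example_edges = [(0, 1, 10), (0, 2, 6), (0, 3, 5), (1, 3, 15), (2, 3, 4), (3, 4, 2)]
total, log = kruskal_trace(5, example_edges)
for u, v, weight, accepted in log:
    verdict = "take" if accepted else "skip (would form a cycle)"
    print(f"{u}-{v} cost {weight}: {verdict}")
print("total:", total)
```

On this graph the edges with costs 2, 4, 5 and 10 are taken, the edges with costs 6 and 15 are skipped, and the total comes to 21.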

https://youtu.be/ivcbaIhrcsE

 

 

class UnionFind:
    """A helper class for the Disjoint Set Union (DSU) data structure."""
    def __init__(self, n):
        # Initially, each node is its own parent.
        self.parent = list(range(n))
        self.rank = [1] * n

    def find(self, i):
        # Find the root of the set containing element i
        if self.parent[i] == i:
            return i
        # Path compression for efficiency
        self.parent[i] = self.find(self.parent[i])
        return self.parent[i]

    def union(self, i, j):
        # Union by rank for efficiency
        root_i = self.find(i)
        root_j = self.find(j)
        if root_i != root_j:
            if self.rank[root_i] > self.rank[root_j]:
                self.parent[root_j] = root_i
            elif self.rank[root_i] < self.rank[root_j]:
                self.parent[root_i] = root_j
            else:
                self.parent[root_j] = root_i
                self.rank[root_i] += 1
            return True # Return True if a merge happened
        return False # Return False if they were already in the same set

def find_min_cost(num_servers, connections):
    """Return the MST cost via Kruskal's algorithm, or -1 if not all servers can be connected."""
    total_cost = 0
    edges_count = 0

    connections = sorted(connections, key = lambda item: item[2])
    unionFind = UnionFind(num_servers)

    for server1, server2, cost in connections:
        # union() returns True only when the edge joins two separate sets,
        # i.e. when adding it does not form a cycle
        if unionFind.union(server1, server2):
            total_cost += cost
            edges_count += 1
  
    if edges_count == num_servers - 1:
        return total_cost
    else:
        return -1
    

if __name__ == "__main__":
    num_servers = 5
    # Each entry is [server1, server2, cost]
    connections = [
        [0, 1, 10],
        [0, 2, 6],
        [0, 3, 5],
        [1, 3, 15],
        [2, 3, 4],
        [3, 4, 2]
    ]

    min_cost = find_min_cost(num_servers, connections) #expect 21
    print(f"The minimum cost to connect all servers is: {min_cost}")

 

Prim's MST Algorithm (vertex-based greedy approach) - (picks the cheapest edge among those selectable from the nodes visited so far)

 

https://youtu.be/cplfcGZmX7I

 

 

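In contrast to Kruskal's algorithm, this technique explores the graph by selecting vertices rather than edges.

As with Kruskal, the result must contain no cycle: at each step the cheapest edge leading from an already-visited node to an unvisited one is chosen, until every node has been visited.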
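The "select a vertex" view is easiest to see in the array-based O(V^2) form of Prim's algorithm, which is a different variant from the heap-based one. The sketch below assumes vertices labeled 0 to V - 1 and starts from vertex 0; `prim_dense` is a name chosen for this illustration.

```python
import math


def prim_dense(num_vertices, edges):
    """Return the MST cost using an O(V^2) array scan, or None if disconnected."""
    # weight[u][v] holds the cheapest edge between u and v (inf if none).
    weight = [[math.inf] * num_vertices for _ in range(num_vertices)]
    for u, v, cost in edges:
        if cost < weight[u][v]:
            weight[u][v] = weight[v][u] = cost

    in_tree = [False] * num_vertices
    # best[v] is the cheapest known edge from the tree to vertex v.
    best = [math.inf] * num_vertices
    if num_vertices:
        best[0] = 0  # the start vertex costs nothing to include
    total_cost = 0

    for _ in range(num_vertices):
        # Pick the vertex outside the tree that is cheapest to attach.
        candidate = -1
        for v in range(num_vertices):
            if not in_tree[v] and (candidate == -1 or best[v] < best[candidate]):
                candidate = v
        if best[candidate] == math.inf:
            return None  # no edge reaches the remaining vertices
        in_tree[candidate] = True
        total_cost += best[candidate]
        # The new vertex may offer cheaper connections to the rest.
        for v in range(num_vertices):
            if not in_tree[v] and weight[candidate][v] < best[v]:
                best[v] = weight[candidate][v]
    return total_cost


# Each edge is (vertex_a, vertex_b, cost).
example_edges = [(0, 1, 10), (0, 2, 6), (0, 3, 5), (1, 3, 15), (2, 3, 4), (3, 4, 2)]
print(prim_dense(5, example_edges))  # prints 21
```

Starting from vertex 0, the vertices join the tree in the order 0, 3, 4, 2, 1. This form tends to suit dense graphs, while a heap-based version is usually preferred when the graph is sparse.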

 

import heapq
import collections

def prims_mst(num_vertices: int, edges: list[list[int]]):
    """
    Calculates the Minimum Spanning Tree (MST) using Prim's algorithm.

    Args:
        num_vertices: The total number of vertices in the graph, labeled 0 to n-1.
        edges: A list of edges, where each edge is [vertex1, vertex2, cost].

    Returns:
        A tuple containing:
        - The total cost of the MST.
        - A list of edges that form the MST.
        Returns (None, []) if the graph is not connected.
    """
    # Create an adjacency list to represent the graph
    adj_list = collections.defaultdict(list)
    for u, v, cost in edges:
        adj_list[u].append((v, cost))
        adj_list[v].append((u, cost))

    # --- Initialization ---
    start_node = 0
    visited = {start_node}
    
    # The min_heap stores tuples of (cost, source, destination)
    min_heap = []
    for neighbor, cost in adj_list[start_node]:
        heapq.heappush(min_heap, (cost, start_node, neighbor))

    total_cost = 0
    mst_edges = []

    # --- Main Loop ---
    # The loop continues until we have V-1 edges in our MST
    while min_heap and len(mst_edges) < num_vertices - 1:
        # 1. Extract the edge with the minimum cost
        cost, u, v = heapq.heappop(min_heap)

        # 2. If the destination vertex is already visited, skip to avoid cycles
        if v in visited:
            continue

        # 3. Add the new vertex and edge to our MST
        visited.add(v)
        total_cost += cost
        mst_edges.append((u, v, cost))
        
        # 4. Add the new vertex's outgoing edges to the heap
        for neighbor, neighbor_cost in adj_list[v]:
            if neighbor not in visited:
                heapq.heappush(min_heap, (neighbor_cost, v, neighbor))

    # Check if a valid MST was formed (i.e., all vertices are connected)
    if len(mst_edges) == num_vertices - 1:
        return total_cost, mst_edges
    else:
        return None, [] # Graph is not connected


# --- Example Usage ---
if __name__ == "__main__":
    num_nodes = 5
    connections = [
        [0, 1, 10], [0, 2, 6], [0, 3, 5],
        [1, 3, 15], [2, 3, 4], [3, 4, 2]
    ]

    cost, mst = prims_mst(num_nodes, connections)

    if cost is not None:
        print(f"The total cost of the Minimum Spanning Tree is: {cost}")
        print("   The edges in the MST are:")
        for u, v, c in mst:
            print(f"     - Edge ({u} - {v}) with cost {c}")
    else:
        print("Could not form a Minimum Spanning Tree (the graph may not be connected).")