Kruskal's Algorithm Complete Code

Chris Tralie

In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
from unionfind import *


class GraphNode:
    def __init__(self):
        self.edges = []
        self.data = {}

def draw_2d_graph(nodes, edges, draw_nodes=True, draw_labels=False, linewidth=2):
    ax = plt.gca()
    ax.set_facecolor((0.9, 0.9, 0.9))
    for (i, j, d) in edges:
        x1, y1 = nodes[i].data['x'], nodes[i].data['y']
        x2, y2 = nodes[j].data['x'], nodes[j].data['y']
        plt.plot([x1, x2], [y1, y2], linewidth=linewidth)
    for i, n in enumerate(nodes):
        if draw_nodes:
            plt.scatter(n.data['x'], n.data['y'], 100, c='k')
            if draw_labels:
                plt.text(n.data['x']+0.002, n.data['y']+0.002, "{}".format(i), zorder=10, c='r', fontsize='xx-large')

def dist_of_edge(e):
    return e[2]



def get_mst_kruskal(nodes, edges):
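    # Consider the edges from shortest to longest
    edges = sorted(edges, key = dist_of_edge)
    djset = UFFast(len(nodes))
    new_edges = []
    for e in edges:
        (i, j, d) = e
        # djset.find(i, j) is used as a connectivity test: if i and j are
        # already connected, this edge would close a cycle, so it is skipped
        if not djset.find(i, j):
            djset.union(i, j)
            new_edges.append(e)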
    return new_edges
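
To see the greedy edge selection at work on something small, here is a minimal self-contained sketch on a 4-node toy graph. It does not use the `unionfind` module; instead it tracks components with a plain list of labels, which is a slower stand-in chosen only for brevity. The names `kruskal_small` and `toy_edges` are hypothetical and not part of the notebook's code.

```python
def kruskal_small(num_nodes, edges):
    """Kruskal's algorithm with a naive component-label list (illustration only)."""
    comp = list(range(num_nodes))  # comp[i] = label of the component containing i
    mst = []
    # Visit edges from smallest to largest weight
    for (i, j, d) in sorted(edges, key=lambda e: e[2]):
        if comp[i] != comp[j]:
            # Merge the two components by relabeling everything in j's component
            old, new = comp[j], comp[i]
            comp = [new if c == old else c for c in comp]
            mst.append((i, j, d))
    return mst

# Toy graph: edges are (node, node, weight) tuples, as in get_mst_kruskal
toy_edges = [(0, 1, 1.0), (1, 2, 2.0), (0, 2, 2.5), (2, 3, 0.5), (1, 3, 3.0)]
toy_mst = kruskal_small(4, toy_edges)
print(toy_mst)  # [(2, 3, 0.5), (0, 1, 1.0), (1, 2, 2.0)]
```

The edges `(0, 2, 2.5)` and `(1, 3, 3.0)` are rejected because, by the time they are considered, their endpoints are already in the same component.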

Complexity of Kruskal's Algorithm / Euclidean MSTs

For a graph with $E$ edges and $N$ nodes, the worst-case time complexity of Kruskal's algorithm is $O(E \log E)$. The dominant step is sorting the edges in increasing order of distance with a comparison-based sort. The union-find step is actually quite fast: over all $E$ edges, it takes $O(E \alpha(N))$ time, where $\alpha$ is the inverse Ackermann function (this bound relies on rank-based balancing and path compression), which, for all practical purposes, is a constant.
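
The `unionfind` module imported above is not shown in this notebook. As a rough sketch of what a structure like `UFFast` could look like, here is a minimal union find with union by rank and path compression. The class name `SimpleUnionFind` and the helper `root` are assumptions made for illustration; only the `find(i, j)` and `union(i, j)` calls mirror how `get_mst_kruskal` uses the structure.

```python
class SimpleUnionFind:
    """Hypothetical stand-in for UFFast (the real class is not shown here)."""
    def __init__(self, N):
        self.parent = list(range(N))  # Each element starts as its own root
        self.rank = [0] * N           # Upper bound on the height of each tree

    def root(self, i):
        # First pass: walk up to the root of i's tree
        r = i
        while self.parent[r] != r:
            r = self.parent[r]
        # Second pass (path compression): point every node on the path at the root
        while self.parent[i] != r:
            nxt = self.parent[i]
            self.parent[i] = r
            i = nxt
        return r

    def find(self, i, j):
        # True if i and j are in the same component
        return self.root(i) == self.root(j)

    def union(self, i, j):
        ri, rj = self.root(i), self.root(j)
        if ri == rj:
            return
        # Union by rank: attach the shallower tree under the deeper one
        if self.rank[ri] < self.rank[rj]:
            ri, rj = rj, ri
        self.parent[rj] = ri
        if self.rank[ri] == self.rank[rj]:
            self.rank[ri] += 1

uf = SimpleUnionFind(5)
uf.union(0, 1)
uf.union(3, 4)
print(uf.find(0, 1))  # True
print(uf.find(1, 3))  # False
```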

For a Euclidean MST, in which the nodes are points in the plane and we use the complete graph on all $N$ nodes, there are $N(N-1)/2$, or $O(N^2)$, edges in the whole graph. This means that Kruskal's algorithm takes $O(N^2 \log (N^2)) = O(2 N^2 \log N) = O(N^2 \log N)$ time.
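
As a quick sanity check of the edge count, the following sketch builds the complete graph on a handful of random points and confirms that it has $N(N-1)/2$ edges. The function name `make_complete_graph_edges` is a hypothetical helper, not part of the notebook's code, and it uses only the standard library rather than numpy.

```python
import itertools
import math
import random

def make_complete_graph_edges(points):
    """Return (i, j, distance) for every unordered pair of points."""
    edges = []
    for i, j in itertools.combinations(range(len(points)), 2):
        d = math.dist(points[i], points[j])  # Euclidean distance
        edges.append((i, j, d))
    return edges

random.seed(0)
N = 20
points = [(random.random(), random.random()) for _ in range(N)]
edges = make_complete_graph_edges(points)
print(len(edges), N*(N-1)//2)  # Both are 190
```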

However, we can do better than this for Euclidean MSTs if we use a graph arising from a Delaunay triangulation of the points. The edges of the Delaunay triangulation form a planar graph, which has $O(N)$ edges, and it can be shown that a subset of the Delaunay edges forms a minimum spanning tree. Therefore, we can narrow down to $O(N)$ edges before applying Kruskal's algorithm. Since the triangulation itself can be computed in $O(N \log N)$ time, the total complexity goes down to $O(N \log N)$. The code below shows how to do this using scipy.

In [2]:
from scipy.spatial import Delaunay

def make_delaunay_graph(N):
    x = np.random.rand(N)
    y = np.random.rand(N)
    nodes = []
    for i in range(N):
        n = GraphNode()
        n.data = {'x':x[i], 'y':y[i]}
        nodes.append(n)
    tri = Delaunay(np.array([x, y]).T).simplices
    edges = set()
    for i in range(tri.shape[0]):
        for k in range(3):
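            # Order the endpoints so that an edge shared by two triangles
            # is stored only once in the set
            i1, i2 = sorted((int(tri[i, k]), int(tri[i, (k+1)%3])))
            d = np.sqrt((x[i1]-x[i2])**2 + (y[i1]-y[i2])**2)
            edges.add((i1, i2, d))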
    return nodes, list(edges)


np.random.seed(0)
nodes, edges = make_delaunay_graph(20)
new_edges = get_mst_kruskal(nodes, edges)

plt.figure(figsize=(10, 15))
plt.subplot(211)
draw_2d_graph(nodes, edges)
plt.subplot(212)
draw_2d_graph(nodes, new_edges)

Mazes

One cool application of spanning trees is the automatic creation of mazes. If we create a graph in which the nodes are on a grid and the edges connect left/right/up/down neighbors, and we assign a random weight to every edge, then running Kruskal's algorithm yields edges which form a maze. Since the result is a spanning tree, there is a path between any pair of vertices. If the start vertex and the end vertex are on the boundary of the grid, then we get a traditional maze.

In [3]:
def make_grid_graph(N, seed = 0):
    """
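    Create a grid graph with left/right/up/down neighbors and random edge weights

    Parameters
    ----------
    N: int
        Resolution of grid
    seed: int
        Seed for the random number generator

    Returns
    -------
    nodes: list of GraphNode, edges: list of (index1, index2, weight)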
    """
    np.random.seed(seed)
    nodes = []
    for i in range(N):
        for j in range(N):
            n = GraphNode()
            n.data = {'x':j, 'y':i}
            nodes.append(n)
    edges = []
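    # Only look at the neighbors below and to the right, so that each
    # undirected edge is added once instead of once in each direction
    neighbs = [[1, 0], [0, 1]]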
    for i in range(N):
        for j in range(N):
            idx1 = i*N + j
            for [di, dj] in neighbs:
                ii = i + di
                jj = j + dj
                if ii >= 0 and jj >= 0 and ii < N and jj < N:
                    idx2 = ii*N + jj
                    edges.append((idx1, idx2, np.random.rand()))
    return nodes, edges

def draw_grid_edges_image(edges, N):
    I = np.zeros((2*N-1, 2*N-1))
    for (i, j, d) in edges:
        x1, y1 = i%N, i//N
        x2, y2 = j%N, j//N
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)
        if np.abs(x1-x2) > 0:
            # Horizontal line
            I[2*y1, 2*x1:2*x2+1] = 1
        else:
            # Vertical line
            I[2*y1:2*y2+1, 2*x1] = 1
    plt.imshow(I)
    plt.gca().invert_yaxis()
    plt.axis('off')

N = 10
nodes, edges = make_grid_graph(N)
new_edges = get_mst_kruskal(nodes, edges)
plt.figure(figsize=(10, 5))
plt.subplot(121)
draw_2d_graph(nodes, new_edges, linewidth=3)
plt.subplot(122)
draw_grid_edges_image(new_edges, N)
In [4]:
plt.figure(figsize=(10, 10))
N = 100
nodes, edges = make_grid_graph(N)
new_edges = get_mst_kruskal(nodes, edges)
draw_grid_edges_image(new_edges, N)
plt.show()
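
Because the maze edges form a spanning tree, the route between any two cells is unique, and one way to recover it is a breadth-first search over the tree edges. The sketch below illustrates this on a tiny hand-written tree on a 2x2 grid. The function name `find_tree_path` and the example edge list `tiny_maze` are hypothetical; the function accepts the same `(i, j, d)` edge tuples that `get_mst_kruskal` returns.

```python
from collections import deque

def find_tree_path(num_nodes, edges, start, end):
    """Return the list of node indices from start to end, or None if unreachable."""
    # Build adjacency lists from the undirected edge list
    neighbors = [[] for _ in range(num_nodes)]
    for (i, j, d) in edges:
        neighbors[i].append(j)
        neighbors[j].append(i)
    # Breadth-first search, remembering how each node was reached
    prev = {start: None}
    queue = deque([start])
    while queue:
        u = queue.popleft()
        if u == end:
            break
        for v in neighbors[u]:
            if v not in prev:
                prev[v] = u
                queue.append(v)
    if end not in prev:
        return None
    # Walk backwards from end to start, then reverse
    path = []
    u = end
    while u is not None:
        path.append(u)
        u = prev[u]
    return path[::-1]

# 2x2 grid with row-major indices:
#   0 1
#   2 3
tiny_maze = [(0, 1, 0.3), (1, 3, 0.1), (3, 2, 0.7)]
print(find_tree_path(4, tiny_maze, 0, 2))  # [0, 1, 3, 2]
```

For a maze built with `make_grid_graph(N)`, the top-left and bottom-right cells have indices `0` and `N*N - 1`, so a call along the lines of `find_tree_path(N*N, new_edges, 0, N*N - 1)` should give the route between opposite corners.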