"""
p66 — Min Cost to Connect All Points (LeetCode 1584, MEDIUM).

MST on a complete graph with Manhattan distances.
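- Prim with min_cost array: O(V^2) — best for dense graphs.
- Prim with a binary heap (lazy deletion): O(V^2 log V) here, since every popped vertex relaxes all others.
- Kruskal with sort + DSU: O(E log E) = O(V^2 log V) — fine but slower here.
- Brute: enumerate all n^(n-2) labeled spanning trees via Prüfer sequences; only feasible for tiny n (n <= 6).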
"""
from __future__ import annotations
import heapq
import itertools
import random
from typing import List


# Sentinel "no edge known yet" weight; compares greater than any finite distance.
INF: float = float("inf")


def _dist(a: List[int], b: List[int]) -> int:
    return abs(a[0] - b[0]) + abs(a[1] - b[1])


def min_cost_connect_prim_array(points: List[List[int]]) -> int:
    """Return the MST weight of *points* under Manhattan distance (dense Prim).

    Keeps, for every vertex outside the tree, the cheapest known edge into the
    tree. Each of the n rounds attaches the cheapest such vertex and then
    relaxes every remaining vertex against it. O(n^2) time, O(n) extra space.
    """
    count = len(points)
    if count <= 1:
        return 0
    # key[v]: cheapest edge weight from the current tree to v (INF = not yet seen).
    key = [INF] * count
    key[0] = 0
    visited = [False] * count
    answer = 0
    for _round in range(count):
        # Pick the first unvisited vertex with the strictly smallest key.
        chosen, chosen_key = -1, INF
        for candidate, weight in enumerate(key):
            if visited[candidate] or not weight < chosen_key:
                continue
            chosen, chosen_key = candidate, weight
        visited[chosen] = True
        answer += int(chosen_key)  # keys start as float INF; the cast keeps the result an int
        anchor = points[chosen]
        for other in range(count):
            if visited[other]:
                continue
            d = _dist(anchor, points[other])
            if d < key[other]:
                key[other] = d
    return answer


class _DSU:
    __slots__ = ("parent", "rank")

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x: int) -> int:
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, x: int, y: int) -> bool:
        rx, ry = self.find(x), self.find(y)
        if rx == ry:
            return False
        if self.rank[rx] < self.rank[ry]:
            rx, ry = ry, rx
        self.parent[ry] = rx
        if self.rank[rx] == self.rank[ry]:
            self.rank[rx] += 1
        return True


def min_cost_connect_kruskal(points: List[List[int]]) -> int:
    """Return the MST weight of *points* under Manhattan distance (Kruskal).

    Materialises all n*(n-1)/2 edges, sorts them by (weight, i, j), and greedily
    accepts each edge that joins two different DSU components. Stops as soon as
    n - 1 edges have been accepted.
    """
    n = len(points)
    if n <= 1:
        return 0
    candidates = sorted(
        (_dist(points[i], points[j]), i, j)
        for i in range(n)
        for j in range(i + 1, n)
    )
    components = _DSU(n)
    cost = 0
    remaining = n - 1  # edges still needed to span all n points
    for weight, a, b in candidates:
        if not components.union(a, b):
            continue
        cost += weight
        remaining -= 1
        if remaining == 0:
            break
    return cost


def min_cost_connect_prim_heap(points: List[List[int]]) -> int:
    """Return the MST weight of *points* under Manhattan distance (heap Prim).

    Lazy-deletion Prim: the heap may hold several entries for one vertex, and
    entries for vertices already in the tree are skipped on pop. ``best[v]``
    records the cheapest edge seen so far into each non-tree vertex. A new
    entry is pushed only when it strictly improves on that value, so
    dominated entries never enter the heap.
    """
    n = len(points)
    if n <= 1:
        return 0
    in_mst = [False] * n
    best = [INF] * n   # cheapest known edge weight from the tree to v
    best[0] = 0
    heap = [(0, 0)]    # (cost, node)
    total = 0
    added = 0
    while heap and added < n:
        cost, u = heapq.heappop(heap)
        if in_mst[u]:
            continue   # stale entry: u was already attached more cheaply
        in_mst[u] = True
        total += cost
        added += 1
        pu = points[u]  # hoisted out of the relaxation loop
        for v in range(n):
            if in_mst[v]:
                continue
            d = _dist(pu, points[v])
            if d < best[v]:
                best[v] = d
                heapq.heappush(heap, (d, v))
    return total


def _spanning_trees(n: int):
    """Yield every spanning tree (as set of frozenset edges) via Prufer sequences."""
    if n == 1:
        yield set()
        return
    if n == 2:
        yield {frozenset((0, 1))}
        return
    for seq in itertools.product(range(n), repeat=n - 2):
        degree = [1] * n
        for x in seq:
            degree[x] += 1
        edges = set()
        seq_list = list(seq)
        for x in seq_list:
            for v in range(n):
                if degree[v] == 1:
                    edges.add(frozenset((x, v)))
                    degree[v] -= 1
                    degree[x] -= 1
                    break
        leaves = [v for v in range(n) if degree[v] == 1]
        edges.add(frozenset((leaves[0], leaves[1])))
        yield edges


def min_cost_connect_brute(points: List[List[int]]) -> int:
    """Return the MST weight by exhaustively scoring every spanning tree.

    Enumerates all n^(n-2) labeled trees via ``_spanning_trees`` and keeps the
    cheapest total Manhattan length. The cost is exponential, so this is only
    meant as a reference oracle for tiny inputs.
    """
    n = len(points)
    if n <= 1:
        return 0
    # Every tree edge is a two-element frozenset (Prüfer decoding never makes
    # self-loops), so it always unpacks into two distinct endpoints.
    return min(
        sum(_dist(points[i], points[j]) for i, j in tree)
        for tree in _spanning_trees(n)
    )


def edge_cases() -> None:
    """Check every solver against hand-verified answers.

    Each case runs through array-Prim, Kruskal, heap-Prim and the brute-force
    oracle. This gives all of them the same edge-case coverage (empty input,
    single point, two points, coincident points), not just the array solver.
    """
    solvers = (
        min_cost_connect_prim_array,
        min_cost_connect_kruskal,
        min_cost_connect_prim_heap,
        min_cost_connect_brute,
    )
    cases = [
        ([[0, 0], [2, 2], [3, 10], [5, 2], [7, 0]], 20),
        ([[3, 12], [-2, 5], [-4, 1]], 18),
        ([], 0),                  # no points: nothing to connect
        ([[0, 0]], 0),
        ([[0, 0], [1, 1]], 2),
        ([[0, 0], [0, 0]], 0),    # coincident points
    ]
    for solver in solvers:
        for pts, expected in cases:
            got = solver(pts)
            assert got == expected, f"{solver.__name__}({pts}) = {got}, expected {expected}"
    print("edge_cases: PASS")


def stress_test() -> None:
    """Cross-check all four solvers on 200 seeded random inputs of 1..6 points."""
    rng = random.Random(42)
    trials = 200
    for trial in range(trials):
        size = rng.randint(1, 6)
        pts = []
        for _ in range(size):
            # x is drawn before y, matching the generator's fixed draw order.
            x = rng.randint(-5, 5)
            y = rng.randint(-5, 5)
            pts.append([x, y])
        a, b, c, d = (
            solve(pts)
            for solve in (
                min_cost_connect_prim_array,
                min_cost_connect_kruskal,
                min_cost_connect_prim_heap,
                min_cost_connect_brute,
            )
        )
        assert a == b == c == d, f"trial {trial}: pts={pts} prim={a} kruskal={b} heap={c} brute={d}"
    print("stress_test: 200 trials — array-Prim, Kruskal, heap-Prim, brute (Prufer) all agree.")


if __name__ == "__main__":
    # Fixed cases run first, then the randomized cross-check.
    for _check in (edge_cases, stress_test):
        _check()
    print("ALL TESTS PASSED")
