"""
p65 — Number of Connected Components in an Undirected Graph (LeetCode 323, MEDIUM).

DSU with path halving + union by rank. O((V + E) * alpha(V)).
"""
from __future__ import annotations
import random
from collections import defaultdict
from typing import List


class DSU:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    ``find`` compresses paths by halving and ``union`` links by rank.
    ``components`` counts the disjoint sets still present; it starts at ``n``
    and drops by one on every successful merge.
    """

    __slots__ = ("parent", "rank", "components")

    def __init__(self, n: int) -> None:
        # Every element starts as the root of its own singleton set.
        self.parent = [node for node in range(n)]
        self.rank = [0 for _ in range(n)]
        self.components = n

    def find(self, x: int) -> int:
        """Return the root of ``x``'s set, halving the traversed path."""
        links = self.parent
        node = x
        while True:
            up = links[node]
            if up == node:
                return node
            # Path halving: re-point node at its grandparent, then jump there.
            grand = links[up]
            links[node] = grand
            node = grand

    def union(self, x: int, y: int) -> bool:
        """Merge the sets holding ``x`` and ``y``.

        Returns True when two distinct sets were merged and False when
        ``x`` and ``y`` were already in the same set.
        """
        root_x, root_y = self.find(x), self.find(y)
        if root_x != root_y:
            # Union by rank: the lower-rank root hangs beneath the other one;
            # on a tie, root_x stays on top and its rank grows by one.
            if self.rank[root_x] < self.rank[root_y]:
                big, small = root_y, root_x
            else:
                big, small = root_x, root_y
            self.parent[small] = big
            if self.rank[big] == self.rank[small]:
                self.rank[big] += 1
            self.components -= 1
            return True
        return False


def count_components_dsu(n: int, edges: List[List[int]]) -> int:
    """Return the number of connected components, computed with a DSU.

    Each edge merges its two endpoints' sets; whatever set count remains
    afterwards is the component count.
    """
    forest = DSU(n)
    for a, b in edges:
        forest.union(a, b)
    return forest.components


def count_components_dfs(n: int, edges: List[List[int]]) -> int:
    """Return the number of connected components via iterative DFS flood fill."""
    neighbours: defaultdict = defaultdict(list)
    for a, b in edges:
        neighbours[a].append(b)
        neighbours[b].append(a)

    visited = [False] * n

    def flood(root: int) -> None:
        # Explicit stack rather than recursion, so long paths cannot hit the
        # interpreter's recursion limit.
        visited[root] = True
        pending = [root]
        while pending:
            node = pending.pop()
            for nxt in neighbours[node]:
                if visited[nxt]:
                    continue
                visited[nxt] = True
                pending.append(nxt)

    components = 0
    for root in range(n):
        if not visited[root]:
            components += 1
            flood(root)
    return components


def count_components_brute(n: int, edges: List[List[int]]) -> int:
    """Count connected components by tracking each component's full member set.

    This is a quick-find style reference oracle for the stress test.
    ``reach[v]`` is the set of every vertex in ``v``'s component, and all
    members of a component share the same set object. Each edge merges the
    two endpoint sets in a single pass and re-points every member at the
    merged set, so no repeat-until-stable iteration is needed. Cost is
    O(E * V) for the merges plus O(V^2) for counting, which is fine for the
    tiny graphs it is used on.
    """
    reach = [{i} for i in range(n)]
    for u, v in edges:
        merged = reach[u] | reach[v]
        for x in merged:
            reach[x] = merged
    # Every vertex belongs to its own reach set (invariant kept by the merge
    # loop), so each component has exactly one member equal to its minimum.
    # Counting those canonical representatives counts the components without
    # relying on object identity.
    return sum(1 for i in range(n) if min(reach[i]) == i)


def edge_cases() -> None:
    """Check hand-picked boundary inputs against every solver.

    Each case is run through the DSU, DFS and brute-force implementations, so
    self-loops, duplicate edges and edgeless graphs are verified for all
    three. Raises AssertionError naming the failing solver and input.
    """
    cases = [
        (5, [[0, 1], [1, 2], [3, 4]], 2),
        (5, [[0, 1], [1, 2], [2, 3], [3, 4]], 1),
        (5, [], 5),
        (1, [], 1),
        (0, [], 0),                 # empty graph has no components
        (3, [[0, 0]], 3),           # self-loop changes nothing
        (3, [[0, 1], [0, 1]], 2),   # duplicate edge
    ]
    solvers = (count_components_dsu, count_components_dfs, count_components_brute)
    for n, edges, expected in cases:
        for solve in solvers:
            got = solve(n, edges)
            assert got == expected, (
                f"{solve.__name__}(n={n}, edges={edges}) = {got}, expected {expected}"
            )
    print("edge_cases: PASS")


def stress_test() -> None:
    """Cross-check the DSU, DFS and brute-force solvers on random small graphs.

    The RNG is seeded with 42, so a failing trial can be reproduced exactly.
    Endpoints are drawn independently, which means self-loops (u == v) and
    duplicate edges can occur.
    """
    rng = random.Random(42)
    for trial in range(500):
        n = rng.randint(1, 10)
        m = rng.randint(0, n * (n - 1) // 2 + 1)
        # List displays evaluate left to right, so u is drawn before v.
        edges = [[rng.randrange(n), rng.randrange(n)] for _ in range(m)]
        dsu_n, dfs_n, brute_n = (
            count_components_dsu(n, edges),
            count_components_dfs(n, edges),
            count_components_brute(n, edges),
        )
        assert dsu_n == dfs_n == brute_n, (
            f"trial {trial}: n={n} edges={edges} dsu={dsu_n} dfs={dfs_n} brute={brute_n}"
        )
    print("stress_test: 500 trials — DSU, DFS, brute transitive-closure all agree.")


if __name__ == "__main__":
    # Run the deterministic checks first, then the randomized cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
