"""
LC 127 — Word Ladder.

Implicit graph: nodes = words, edges = pairs differing by exactly one letter.

Implementations:
    1. ladder_length_bucket_bfs       — bucket adjacency + BFS. Canonical.
    2. ladder_length_bidirectional_bfs — bidirectional BFS w/ "expand smaller frontier".
    3. ladder_length_brute_bfs        — O(N^2 * L) pairwise adjacency BFS. Oracle.

INVARIANT (BFS): distance counts WORDS in the path (including both endpoints),
                 so beginWord starts at distance 1.
INVARIANT (bucket): two distinct words of the same length share a wildcard
                    key (e.g., "h*t") iff they differ at exactly one position.
INVARIANT (bidirectional): always expand the SMALLER frontier; intersection
                           check happens BEFORE expansion writes to next frontier.
INVARIANT: visited marked at enqueue, never at dequeue.
"""

from __future__ import annotations

import random
import string
from collections import defaultdict, deque


# ----------------------------------------------------------------------
# Solution A — Bucket BFS (canonical).
# ----------------------------------------------------------------------

def ladder_length_bucket_bfs(begin: str, end: str, word_list: list[str]) -> int:
    """Shortest ladder length via single-source BFS over wildcard buckets.

    Every word of length ``len(begin)`` in ``word_list ∪ {begin}`` is filed
    under each of its wildcard keys (``"hot"`` → ``"*ot"``, ``"h*t"``,
    ``"ho*"``). Distinct words that share a key are exactly one substitution
    apart, so BFS can walk buckets instead of comparing word pairs.

    Words whose length differs from ``begin`` are ignored. A substitution never
    changes length, so such words can never be adjacent. This matches the
    pairwise oracle, which rejects length mismatches.

    Returns:
        The number of words on a shortest ladder (both endpoints counted), or
        0 if ``end`` is not in ``word_list`` or cannot be reached.
    """
    word_set = set(word_list)
    if end not in word_set:
        return 0
    if begin == end:
        return 1

    L = len(begin)
    if len(end) != L:
        # Substitutions preserve length, so a different-length end is unreachable.
        return 0

    buckets: dict[str, list[str]] = defaultdict(list)
    # Build buckets over wordList ∪ {begin} so the begin word can reach its neighbors.
    # Only length-L words are bucketed. A shorter word w would emit the key
    # w + "*" (at i == len(w)), which collides with keys of length-L words and
    # fabricates edges.
    for w in word_set | {begin}:
        if len(w) != L:
            continue
        for i in range(L):
            buckets[w[:i] + "*" + w[i + 1:]].append(w)

    visited: set[str] = {begin}
    q: deque[tuple[str, int]] = deque([(begin, 1)])
    while q:
        word, dist = q.popleft()
        if word == end:
            return dist
        for i in range(L):
            key = word[:i] + "*" + word[i + 1:]
            for neigh in buckets[key]:
                if neigh not in visited:
                    visited.add(neigh)  # mark at enqueue, never at dequeue
                    q.append((neigh, dist + 1))
            # Every member of this bucket is now visited, so a later word with
            # the same key would find nothing new: skip the rescan.
            buckets[key] = []
    return 0


# ----------------------------------------------------------------------
# Solution B — Bidirectional BFS.
# ----------------------------------------------------------------------

def ladder_length_bidirectional_bfs(begin: str, end: str, word_list: list[str]) -> int:
    """Shortest ladder length via bidirectional BFS over wildcard buckets.

    One BFS grows from ``begin`` and one from ``end``. Each round expands the
    smaller frontier by one full level. As soon as a neighbour of the expanding
    frontier lies in the opposite frontier, the two half-paths are joined.

    Only words of length ``len(begin)`` take part in the graph. Words of other
    lengths can never be one substitution apart, which matches the pairwise
    oracle's adjacency rule.

    Returns:
        The number of words on a shortest ladder (both endpoints counted), or
        0 if ``end`` is not in ``word_list`` or cannot be reached.
    """
    word_set = set(word_list)
    if end not in word_set:
        return 0
    if begin == end:
        return 1

    L = len(begin)
    if len(end) != L:
        # Substitutions preserve length, so a different-length end is unreachable.
        return 0

    buckets: dict[str, list[str]] = defaultdict(list)
    # Only length-L words are bucketed. A shorter word w would emit the key
    # w + "*" (at i == len(w)), which collides with keys of length-L words and
    # fabricates edges.
    for w in word_set | {begin}:
        if len(w) != L:
            continue
        for i in range(L):
            buckets[w[:i] + "*" + w[i + 1:]].append(w)

    # frontiers track {word: words on the path from that side's endpoint to word}
    forward: dict[str, int] = {begin: 1}
    backward: dict[str, int] = {end: 1}
    visited_fwd: set[str] = {begin}
    visited_bwd: set[str] = {end}

    while forward and backward:
        # Always expand the smaller frontier. The visited sets swap together
        # with the frontiers so each set stays with its own search.
        if len(forward) > len(backward):
            forward, backward = backward, forward
            visited_fwd, visited_bwd = visited_bwd, visited_fwd

        next_frontier: dict[str, int] = {}
        for word, dist in forward.items():
            for i in range(L):
                key = word[:i] + "*" + word[i + 1:]
                for neigh in buckets[key]:
                    # Intersection check BEFORE writing to next_frontier. Both
                    # counts already include their endpoint, so the joined
                    # path has dist + backward[neigh] words.
                    if neigh in backward:
                        return dist + backward[neigh]
                    if neigh not in visited_fwd:
                        visited_fwd.add(neigh)
                        next_frontier[neigh] = dist + 1
        forward = next_frontier
    # One side ran out of new words without meeting the other: no ladder.
    return 0


# ----------------------------------------------------------------------
# Solution C — Brute BFS (pairwise adjacency). Oracle.
# ----------------------------------------------------------------------

def ladder_length_brute_bfs(begin: str, end: str, word_list: list[str]) -> int:
    """Oracle: BFS over an explicitly materialized adjacency list.

    Every pair of vertices in ``word_list ∪ {begin}`` is compared directly,
    so building the graph costs O(N^2 * L). It is slow, but it uses no
    wildcard-key shortcuts, which makes it the reference the faster solvers
    are checked against. Words of different lengths are never adjacent.

    Returns:
        The number of words on a shortest ladder (both endpoints counted), or
        0 if ``end`` is not in ``word_list`` or cannot be reached.
    """
    dictionary = set(word_list)
    if end not in dictionary:
        return 0
    if begin == end:
        return 1

    vertices = list(dictionary | {begin})
    position = {word: k for k, word in enumerate(vertices)}

    def one_apart(x: str, y: str) -> bool:
        # A single substitution keeps the length, so unequal lengths never match.
        if len(x) != len(y):
            return False
        return sum(cx != cy for cx, cy in zip(x, y)) == 1

    count = len(vertices)
    neighbours: list[list[int]] = [[] for _ in vertices]
    for p, x in enumerate(vertices):
        for r in range(p + 1, count):
            if one_apart(x, vertices[r]):
                neighbours[p].append(r)
                neighbours[r].append(p)

    source, goal = position[begin], position[end]
    # depth[k] == 0 means "not discovered yet". Otherwise it is the word count
    # of the shortest path from begin to vertex k. Discovery is marked at enqueue.
    depth = [0] * count
    depth[source] = 1
    pending = deque([source])
    while pending:
        cur = pending.popleft()
        if cur == goal:
            return depth[cur]
        for nxt in neighbours[cur]:
            if not depth[nxt]:
                depth[nxt] = depth[cur] + 1
                pending.append(nxt)
    return 0


# ----------------------------------------------------------------------
# Helpers.
# ----------------------------------------------------------------------

def random_word(rng: random.Random, L: int, alphabet: str = "abcde") -> str:
    """Return ``L`` characters, each picked from ``alphabet`` with ``rng.choice``."""
    letters = [rng.choice(alphabet) for _ in range(L)]
    return "".join(letters)


def random_word_problem(rng: random.Random) -> tuple[str, str, list[str]]:
    """Draw a random ``(begin, end, word_list)`` instance for the stress test.

    All words share one length ``L`` in [2, 5] over the alphabet ``"abcd"``.
    ``word_list`` holds up to 12 distinct words. ``end`` is always taken from
    ``word_list``. ``begin`` is drawn independently, so it may or may not be in
    the list and may equal ``end``.

    The result depends only on the state of ``rng``, so a seeded ``rng`` gives
    the same instance in every process.
    """
    L = rng.randint(2, 5)
    alphabet = "abcd"  # small so 1-letter neighbors are common
    n_words = rng.randint(2, 12)
    # De-duplicate with dict.fromkeys (first-occurrence order), not a set.
    # Set iteration order for str depends on per-process hash randomization.
    # That would make the word order, and hence rng.choice(words) below,
    # differ between runs despite a fixed seed.
    words = list(dict.fromkeys(random_word(rng, L, alphabet) for _ in range(n_words)))
    begin = random_word(rng, L, alphabet)
    end = rng.choice(words)
    return begin, end, words


# ----------------------------------------------------------------------
# Stress test.
# ----------------------------------------------------------------------

def stress_test() -> None:
    """Cross-check all three solvers on seeded random dictionaries.

    Each trial draws a fresh instance from ``random_word_problem`` and asserts
    that the bucket BFS, the bidirectional BFS and the brute-force oracle return
    the same ladder length. On disagreement, the failing instance goes into the
    assertion message.
    """
    rng = random.Random(42)
    n_iter = 200
    trial = 0
    while trial < n_iter:
        begin, end, words = random_word_problem(rng)
        bucket = ladder_length_bucket_bfs(begin, end, words)
        bidir = ladder_length_bidirectional_bfs(begin, end, words)
        brute = ladder_length_brute_bfs(begin, end, words)
        assert bucket == bidir == brute, (
            f"Disagreement: bucket={bucket} bidir={bidir} brute={brute}\n"
            f"begin={begin!r} end={end!r} words={words!r}"
        )
        trial += 1
    summary = (
        f"stress_test: {n_iter} random dictionaries — bucket BFS, bidirectional BFS, "
        "and brute pairwise BFS all agree."
    )
    print(summary)


# ----------------------------------------------------------------------
# Edge cases.
# ----------------------------------------------------------------------

def edge_cases() -> None:
    """Run hand-picked regression cases against every solver.

    Each case is ``(begin, end, word_list, expected_length)``. Cases are checked
    in order, and every solver is tried on one case before moving to the next.
    A failing assertion names the offending solver.
    """
    solvers = (
        ladder_length_bucket_bfs,
        ladder_length_bidirectional_bfs,
        ladder_length_brute_bfs,
    )
    cases = [
        # LC canonical: hit → hot → dot → dog → cog (length 5).
        ("hit", "cog", ["hot", "dot", "dog", "lot", "log", "cog"], 5),
        # LC canonical: endWord not in list (length 0).
        ("hit", "cog", ["hot", "dot", "dog", "lot", "log"], 0),
        # begin == end: LC's constraints exclude this input. These solvers
        # return 1 because the path is just [begin].
        ("hit", "hit", ["hit"], 1),
        # Single-letter words: a and c already differ in their only letter, so
        # the ladder is a → c directly (length 2). b is never needed.
        ("a", "c", ["a", "b", "c"], 2),
        # No path possible: end exists but is unreachable. From aaa the
        # intermediates only reach aab, aac, and zzz differs from each of
        # them in all three letters.
        ("aaa", "zzz", ["aab", "aac", "zzz"], 0),
        # Direct 1-edit: begin and end differ by one letter, end in list.
        ("cat", "bat", ["bat"], 2),
        # Same word in list as begin (shouldn't cause issues).
        ("cat", "bat", ["cat", "bat"], 2),
    ]
    for begin, end, words, expected in cases:
        for fn in solvers:
            assert fn(begin, end, words) == expected, fn.__name__

    print("edge_cases: PASS")


if __name__ == "__main__":
    # Deterministic hand-picked cases first, then the randomized cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
