"""
p60 — Edit Distance / Levenshtein (LeetCode 72, HARD).

dp[i][j] = edit distance(word1[:i], word2[:j]).
Match    -> dp[i-1][j-1]
Else     -> 1 + min(dp[i-1][j-1] sub, dp[i-1][j] del, dp[i][j-1] ins).
Base     -> dp[i][0] = i; dp[0][j] = j.
"""
from __future__ import annotations
import random
from functools import lru_cache
from typing import List, Tuple


def min_distance_2d(word1: str, word2: str) -> int:
    """Return the Levenshtein distance between ``word1`` and ``word2``.

    Keeps the full (len(word1) + 1) x (len(word2) + 1) table, where
    ``table[r][c]`` is the distance between ``word1[:r]`` and ``word2[:c]``.
    """
    # Row 0: turning the empty prefix into word2[:c] costs c insertions.
    table = [list(range(len(word2) + 1))]
    for r, src in enumerate(word1, start=1):
        above = table[-1]
        # Column 0: turning word1[:r] into the empty prefix costs r deletions.
        row = [r]
        for c, dst in enumerate(word2, start=1):
            if src == dst:
                # Equal characters cost nothing; carry the diagonal value.
                row.append(above[c - 1])
            else:
                substitute = above[c - 1]
                delete = above[c]
                insert = row[c - 1]
                row.append(1 + min(substitute, delete, insert))
        table.append(row)
    return table[len(word1)][len(word2)]


def min_distance_1d(word1: str, word2: str) -> int:
    """Return the Levenshtein distance using a single rolling row.

    Same recurrence as the 2-D table, but only one row is kept. The shorter
    word is placed on the inner dimension so the row has
    ``min(len(word1), len(word2)) + 1`` entries.
    """
    # Edit distance is symmetric, so the arguments may be swapped freely.
    if len(word2) > len(word1):
        longer, shorter = word2, word1
    else:
        longer, shorter = word1, word2

    row = list(range(len(shorter) + 1))
    for r, long_ch in enumerate(longer, start=1):
        # `diagonal` holds the previous row's value one column to the left.
        diagonal = row[0]
        row[0] = r
        for c, short_ch in enumerate(shorter, start=1):
            above = row[c]  # previous row's value, about to be overwritten
            if long_ch == short_ch:
                row[c] = diagonal
            else:
                row[c] = 1 + min(diagonal, above, row[c - 1])
            diagonal = above
    return row[-1]


def min_distance_reconstruct(word1: str, word2: str) -> List[Tuple[str, str]]:
    """Return one minimum-cost edit script turning ``word1`` into ``word2``.

    The script is a list of ``(op, detail)`` pairs in left-to-right order:

    * ``("match", c)``   -- keep character ``c``
    * ``("sub", "x->y")`` -- replace ``x`` (from word1) with ``y`` (from word2)
    * ``("del", c)``     -- drop character ``c`` of word1
    * ``("ins", c)``     -- insert character ``c`` of word2

    The number of non-"match" entries equals the edit distance.
    """
    # Forward pass: table[r][c] = distance(word1[:r], word2[:c]).
    table: List[List[int]] = [list(range(len(word2) + 1))]
    for r, src in enumerate(word1, start=1):
        above = table[-1]
        row = [r]
        for c, dst in enumerate(word2, start=1):
            if src == dst:
                row.append(above[c - 1])
            else:
                row.append(1 + min(above[c - 1], above[c], row[c - 1]))
        table.append(row)

    # Backward pass: walk from the bottom-right corner to the origin,
    # preferring match, then substitution, then deletion, then insertion.
    r, c = len(word1), len(word2)
    script: List[Tuple[str, str]] = []
    while r > 0 or c > 0:
        inside = r > 0 and c > 0
        if inside and word1[r - 1] == word2[c - 1]:
            script.append(("match", word1[r - 1]))
            r -= 1
            c -= 1
            continue
        here = table[r][c]
        if inside and here == table[r - 1][c - 1] + 1:
            script.append(("sub", f"{word1[r-1]}->{word2[c-1]}"))
            r -= 1
            c -= 1
            continue
        if r > 0 and here == table[r - 1][c] + 1:
            script.append(("del", word1[r - 1]))
            r -= 1
            continue
        # Only an insertion can explain the current cell.
        script.append(("ins", word2[c - 1]))
        c -= 1

    # Operations were collected back-to-front.
    return script[::-1]


def min_distance_brute(word1: str, word2: str) -> int:
    """Return the Levenshtein distance via top-down memoized recursion.

    Serves as an independent reference for the table-based versions.
    """

    @lru_cache(maxsize=None)
    def solve(len1: int, len2: int) -> int:
        # Distance between word1[:len1] and word2[:len2].
        if len1 == 0 or len2 == 0:
            # One prefix is empty: the cost is the length of the other.
            return len1 + len2
        shrink_both = solve(len1 - 1, len2 - 1)
        if word1[len1 - 1] == word2[len2 - 1]:
            return shrink_both
        return 1 + min(shrink_both, solve(len1 - 1, len2), solve(len1, len2 - 1))

    return solve(len(word1), len(word2))


def _apply_ops(word1: str, ops: List[Tuple[str, str]]) -> str:
    """Replay an edit script against ``word1`` and return the resulting string.

    Args:
        word1: Source string the script is applied to.
        ops: ``(op, detail)`` pairs in left-to-right order, in the format
            produced by ``min_distance_reconstruct``:
            ``"match"`` copies the current character of ``word1``,
            ``"sub"`` emits the part of ``detail`` after ``"->"``,
            ``"del"`` skips the current character of ``word1``,
            ``"ins"`` emits ``detail``.

    Returns:
        The string produced by applying every operation in order.

    Raises:
        ValueError: If an operation name is not one of the four above.
            A malformed script is reported instead of being silently skipped.
        IndexError: If a ``"match"`` is reached after ``word1`` is exhausted,
            or a ``"sub"`` detail has no ``"->"`` separator.
    """
    out: List[str] = []
    i = 0  # index of the next unconsumed character of word1
    for op, detail in ops:
        if op == "match":
            out.append(word1[i])
            i += 1
        elif op == "sub":
            out.append(detail.split("->")[1])
            i += 1
        elif op == "del":
            i += 1
        elif op == "ins":
            out.append(detail)
        else:
            raise ValueError(f"unknown edit op: {op!r}")
    return "".join(out)


def edge_cases() -> None:
    """Check hand-picked inputs against their known edit distances.

    Every case is run through all three distance implementations
    (``min_distance_2d``, ``min_distance_1d``, ``min_distance_brute``), so the
    boundary inputs -- empty strings, identical strings, and a shorter first
    word, which triggers the swap in ``min_distance_1d`` -- are exercised on
    each of them. The reconstructed script for "horse" -> "ros" is then
    replayed and its cost checked.

    Raises:
        AssertionError: If any implementation disagrees with the expected
            value, or the reconstructed script is wrong.
    """
    cases = (
        ("horse", "ros", 3),
        ("intention", "execution", 5),
        ("", "abc", 3),
        ("abc", "", 3),
        ("", "", 0),
        ("abc", "abc", 0),
        ("a", "b", 1),
    )
    solvers = (min_distance_2d, min_distance_1d, min_distance_brute)
    for a, b, expected in cases:
        for solver in solvers:
            got = solver(a, b)
            assert got == expected, (
                f"{solver.__name__}({a!r}, {b!r}) = {got}, expected {expected}"
            )

    ops = min_distance_reconstruct("horse", "ros")
    applied = _apply_ops("horse", ops)
    assert applied == "ros", f"reconstruct invalid: ops={ops} applied={applied!r}"
    # Every non-match operation costs exactly one edit.
    cost = sum(1 for op, _ in ops if op != "match")
    assert cost == 3, f"reconstruct cost {cost} != 3"
    print("edge_cases: PASS")


def stress_test() -> None:
    """Cross-check all implementations on 500 seeded random word pairs.

    Words are drawn from the alphabet "abc" with lengths 0..7. For each pair
    the three distance functions must agree, and the reconstructed edit
    script must both produce the target word and cost exactly the distance.
    """
    rng = random.Random(42)
    alphabet = "abc"

    def random_word() -> str:
        # The length is drawn before the letters, keeping the seeded
        # stream consumed in a fixed order.
        length = rng.randint(0, 7)
        return "".join(rng.choice(alphabet) for _ in range(length))

    for trial in range(500):
        a = random_word()
        b = random_word()

        # All three distance implementations must agree.
        x, y, z = min_distance_2d(a, b), min_distance_1d(a, b), min_distance_brute(a, b)
        assert x == y == z, f"trial {trial}: a={a!r} b={b!r} 2d={x} 1d={y} brute={z}"

        # The script must actually transform a into b ...
        ops = min_distance_reconstruct(a, b)
        applied = _apply_ops(a, ops)
        assert applied == b, f"reconstruct invalid: a={a!r} b={b!r} ops={ops} applied={applied!r}"

        # ... using exactly as many edits as the reported distance.
        cost = sum(op != "match" for op, _ in ops)
        assert cost == x, f"reconstruct cost {cost} != dp {x} for a={a!r} b={b!r}"

    print("stress_test: 500 trials — 2d, 1d, brute agree; reconstruct yields correct script.")


if __name__ == "__main__":
    # Fixed cases first, then the randomized cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
