"""
p59 — Longest Common Subsequence (LeetCode 1143, MEDIUM).

dp[i][j] = length of the LCS of text1[:i] and text2[:j]   (dp[0][*] = dp[*][0] = 0)
match    (text1[i-1] == text2[j-1]) -> dp[i-1][j-1] + 1
mismatch                            -> max(dp[i-1][j], dp[i][j-1])
"""
from __future__ import annotations
import random
from typing import List


def lcs_2d(text1: str, text2: str) -> int:
    """Return the LCS length of *text1* and *text2* using a full table.

    ``table[r][c]`` holds the LCS length of the prefixes ``text1[:r]`` and
    ``text2[:c]``. Row 0 and column 0 stay zero because an empty prefix
    shares nothing. Time and space are both O(len(text1) * len(text2)).
    """
    rows, cols = len(text1) + 1, len(text2) + 1
    table = [[0] * cols for _ in range(rows)]
    for r, a in enumerate(text1, start=1):
        above, here = table[r - 1], table[r]
        for c, b in enumerate(text2, start=1):
            # Extend the diagonal on a match; otherwise inherit the better neighbour.
            here[c] = above[c - 1] + 1 if a == b else max(above[c], here[c - 1])
    return table[-1][-1]


def lcs_1d(text1: str, text2: str) -> int:
    """Return the LCS length of *text1* and *text2* in O(min(m, n)) extra space.

    A single row ``row`` is rolled forward. Before a character of the longer
    string is processed, ``row`` holds dp[i-1][*]. ``diag`` carries
    dp[i-1][j-1] across the inner loop, because that cell has already been
    overwritten one step earlier.
    """
    # Keep the shorter string as the column dimension so the rolling row is small.
    longer, shorter = (text2, text1) if len(text2) > len(text1) else (text1, text2)
    row = [0] * (len(shorter) + 1)
    for a in longer:
        diag = 0  # dp[i-1][0]
        for j, b in enumerate(shorter, start=1):
            up = row[j]  # dp[i-1][j]; becomes the diagonal for the next column
            if a == b:
                row[j] = diag + 1
            else:
                row[j] = max(up, row[j - 1])
            diag = up
    return row[-1]


def lcs_reconstruct(text1: str, text2: str) -> str:
    """Return one longest common subsequence of *text1* and *text2*.

    First builds the full LCS-length table, then walks back from the
    bottom-right corner:

    - On a character match, the walk keeps that character and steps diagonally.
    - Otherwise it moves toward the larger neighbour. On ties it prefers
      "up" (dropping a character of text1).

    That tie preference determines which of several equally long answers is
    returned.
    """
    m, n = len(text1), len(text2)
    table = [[0] * (n + 1) for _ in range(m + 1)]
    for r, a in enumerate(text1, start=1):
        for c, b in enumerate(text2, start=1):
            if a == b:
                table[r][c] = table[r - 1][c - 1] + 1
            else:
                table[r][c] = max(table[r - 1][c], table[r][c - 1])

    picked: List[str] = []
    r, c = m, n
    while r and c:
        a, b = text1[r - 1], text2[c - 1]
        if a == b:
            picked.append(a)
            r, c = r - 1, c - 1
        elif table[r - 1][c] < table[r][c - 1]:
            c -= 1
        else:
            # Tie or "up" is strictly better: drop a character of text1.
            r -= 1
    # Characters were collected back-to-front.
    return "".join(reversed(picked))


def lcs_brute(text1: str, text2: str) -> int:
    """Compute the LCS length exhaustively; intended as a test oracle.

    Enumerates every subsequence of *text1* with a bitmask (2**len(text1)
    candidates) and keeps the longest one that is also a subsequence of
    *text2*. Only practical for very short inputs.
    """
    size = len(text1)
    longest = 0
    for mask in range(2 ** size):
        candidate = "".join(ch for bit, ch in enumerate(text1) if (mask >> bit) & 1)
        remaining = iter(text2)
        # `ch in remaining` consumes the iterator up to the match: a greedy in-order scan.
        if all(ch in remaining for ch in candidate) and len(candidate) > longest:
            longest = len(candidate)
    return longest


def _is_common_subseq(sub: str, a: str, b: str) -> bool:
    for s, target in [(a, sub), (b, sub)]:
        # ensure `target` is a subseq of `s`
        k = 0
        for ch in s:
            if k < len(target) and ch == target[k]:
                k += 1
        if k != len(target):
            return False
    return True


def edge_cases() -> None:
    """Check every LCS implementation on small hand-picked inputs.

    Each (text1, text2, expected_length) case is run through lcs_2d, lcs_1d
    and lcs_reconstruct, so all three share the same coverage. This includes:

    - empty strings;
    - the path where lcs_1d swaps its arguments (first argument shorter).

    Exact reconstructions are pinned only where the LCS string is unique.
    """
    cases = [
        ("abcde", "ace", 3),
        ("ace", "abcde", 3),  # shorter first: exercises lcs_1d's swap path
        ("abc", "abc", 3),
        ("abc", "def", 0),
        ("", "abc", 0),
        ("abc", "", 0),
        ("", "", 0),
        ("a", "a", 1),
        ("ab", "ba", 1),
    ]
    for text1, text2, expected in cases:
        assert lcs_2d(text1, text2) == expected, (text1, text2)
        assert lcs_1d(text1, text2) == expected, (text1, text2)
        assert len(lcs_reconstruct(text1, text2)) == expected, (text1, text2)
    # Cases whose longest common subsequence is unique.
    rec = lcs_reconstruct("abcde", "ace")
    assert rec == "ace"
    rec2 = lcs_reconstruct("abc", "def")
    assert rec2 == ""
    print("edge_cases: PASS")


def stress_test(trials: int = 500, seed: int = 42, max_len: int = 8) -> None:
    """Cross-check all implementations on random short strings over "abc".

    For each random pair (a, b), asserts the following:

    - lcs_2d, lcs_1d, lcs_brute and len(lcs_reconstruct) all agree;
    - the reconstructed string really is a common subsequence of a and b.

    Args:
        trials: number of random (a, b) pairs to test.
        seed: seed for the private RNG, so failures are reproducible.
        max_len: maximum length of each random string. lcs_brute is
            exponential in len(a), so keep this small.
    """
    rng = random.Random(seed)
    alphabet = "abc"

    def rand_str() -> str:
        # The length is drawn before the characters, which keeps the RNG
        # sequence stable for a given seed.
        return "".join(rng.choice(alphabet) for _ in range(rng.randint(0, max_len)))

    for trial in range(trials):
        a = rand_str()
        b = rand_str()
        x = lcs_2d(a, b)
        y = lcs_1d(a, b)
        z = lcs_brute(a, b)
        rec = lcs_reconstruct(a, b)
        assert x == y == z == len(rec), f"trial {trial}: a={a!r} b={b!r} 2d={x} 1d={y} brute={z} rec={rec!r}"
        # The empty string is trivially a common subsequence, so no guard is needed.
        assert _is_common_subseq(rec, a, b), f"rec invalid: a={a!r} b={b!r} rec={rec!r}"
    # Derive the count from `trials` so the message cannot drift from the loop.
    print(f"stress_test: {trials} trials — 2d, 1d, brute, reconstruct agree and rec is valid.")


if __name__ == "__main__":
    # Run the deterministic checks first, then the randomized cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
