"""
p73 — Interleaving String (LeetCode 97, HARD).

dp[i][j] = True iff s3[:i+j] is an interleaving of s1[:i] and s2[:j].
Recurrence:
  dp[i][j] = (dp[i-1][j] and s1[i-1]==s3[i+j-1])
          or (dp[i][j-1] and s2[j-1]==s3[i+j-1])
"""
from __future__ import annotations
import random
from functools import lru_cache


def is_interleave_dp(s1: str, s2: str, s3: str) -> bool:
    """Return True iff ``s3`` is an interleaving of ``s1`` and ``s2``.

    Fills the full (len(s1)+1) x (len(s2)+1) table in which ``table[a][b]``
    records whether ``s3[:a+b]`` can be formed from ``s1[:a]`` and ``s2[:b]``.
    O(len(s1) * len(s2)) time and memory.
    """
    rows, cols = len(s1), len(s2)
    if len(s3) != rows + cols:
        return False
    table = [[False for _ in range(cols + 1)] for _ in range(rows + 1)]
    for a in range(rows + 1):
        for b in range(cols + 1):
            if a == 0 and b == 0:
                # Two empty prefixes trivially interleave to the empty prefix.
                table[a][b] = True
                continue
            target = s3[a + b - 1]
            ok = False
            if a > 0 and table[a - 1][b] and s1[a - 1] == target:
                ok = True  # last character of s3[:a+b] came from s1
            elif b > 0 and table[a][b - 1] and s2[b - 1] == target:
                ok = True  # last character of s3[:a+b] came from s2
            table[a][b] = ok
    return table[rows][cols]


def is_interleave_1d(s1: str, s2: str, s3: str) -> bool:
    """Return True iff ``s3`` is an interleaving of ``s1`` and ``s2``.

    Rolling one-row version of the 2-D recurrence. While processing row ``i``
    (characters consumed from ``s1``), ``dp[j]`` holds whether ``s3[:i+j]``
    can be formed from ``s1[:i]`` and ``s2[:j]``. The row is sized by the
    shorter string, so memory is O(min(len(s1), len(s2))) and time is
    O(len(s1) * len(s2)).
    """
    m, n = len(s1), len(s2)
    if m + n != len(s3):
        return False
    # INVARIANT: interleaving is symmetric in (s1, s2), so swap freely to make
    # s2 the SHORTER string. The rolling row is indexed by positions in s2,
    # which is what yields O(min(m, n)) memory.
    if m < n:
        s1, s2 = s2, s1
        m, n = n, m
    dp = [False] * (n + 1)
    dp[0] = True
    # Row i == 0: only characters from s2 have been consumed.
    for j in range(1, n + 1):
        dp[j] = dp[j - 1] and s2[j - 1] == s3[j - 1]
    for i in range(1, m + 1):
        c1 = s1[i - 1]
        # Column j == 0: only characters from s1 have been consumed.
        dp[0] = dp[0] and c1 == s3[i - 1]
        for j in range(1, n + 1):
            target = s3[i + j - 1]
            # dp[j] still holds row i-1 (last char taken from s1);
            # dp[j-1] already holds row i (last char taken from s2).
            from_s1 = dp[j] and c1 == target
            from_s2 = dp[j - 1] and s2[j - 1] == target
            dp[j] = from_s1 or from_s2
    return dp[n]


def is_interleave_memo(s1: str, s2: str, s3: str) -> bool:
    """Return True iff ``s3`` is an interleaving of ``s1`` and ``s2``.

    Top-down formulation: ``can(p, q)`` asks whether the suffix ``s3[p+q:]``
    can be built from ``s1[p:]`` and ``s2[q:]``. Results are cached per
    ``(p, q)`` pair, so each state is solved at most once. Recursion depth can
    reach ``len(s1) + len(s2)``.
    """
    len1, len2 = len(s1), len(s2)
    if len(s3) != len1 + len2:
        return False

    @lru_cache(maxsize=None)
    def can(p: int, q: int) -> bool:
        if p == len1 and q == len2:
            return True
        # p + q < len(s3) here, because at least one string is unfinished.
        here = s3[p + q]
        # Try drawing from s1 first, then fall back to s2.
        take_first = p < len1 and s1[p] == here and can(p + 1, q)
        return take_first or (q < len2 and s2[q] == here and can(p, q + 1))

    return can(0, 0)


def is_interleave_brute(s1: str, s2: str, s3: str) -> bool:
    """Return True iff ``s3`` is an interleaving of ``s1`` and ``s2``.

    Exponential depth-first search over interleavings with no memoization,
    abandoning a partial interleaving as soon as it stops being a prefix of
    ``s3``. Intended only as an oracle for small inputs.
    """
    if len(s3) != len(s1) + len(s2):
        return False

    def extend(i: int, j: int, prefix: str) -> bool:
        if i == len(s1) and j == len(s2):
            return prefix == s3
        # prefix always has length i + j <= len(s3), so only a content
        # mismatch can rule it out.
        if not s3.startswith(prefix):
            return False
        if i < len(s1):
            if extend(i + 1, j, prefix + s1[i]):
                return True
        if j < len(s2):
            if extend(i, j + 1, prefix + s2[j]):
                return True
        return False

    return extend(0, 0, "")


def edge_cases() -> None:
    """Check hand-picked cases against every implementation in this module.

    Each case runs through the 2-D DP, the rolling 1-D DP, the memoized
    recursion and the brute-force search. This covers empty strings, one
    empty input, length mismatches and branch ambiguity for all four
    implementations, not just the 2-D DP. Raises AssertionError naming the
    first implementation and input that disagree with the expected answer.
    """
    impls = (
        is_interleave_dp,
        is_interleave_1d,
        is_interleave_memo,
        is_interleave_brute,
    )
    cases = [
        ("aabcc", "dbbca", "aadbbcbcac", True),  # LeetCode example 1
        ("aabcc", "dbbca", "aadbbbaccc", False),  # LeetCode example 2
        ("", "", "", True),
        ("", "abc", "abc", True),
        ("abc", "", "abc", True),
        ("a", "b", "ab", True),
        ("a", "b", "ba", True),
        # Both strings can supply the 'b' at s3[3]; the answer must not
        # depend on which branch an implementation explores first.
        ("aabc", "abad", "aaabbacd", True),
        # Length mismatches are rejected up front.
        ("a", "b", "abc", False),
        ("", "", "a", False),
        # Right length, wrong content.
        ("abc", "", "abd", False),
    ]
    for s1, s2, s3, expected in cases:
        for impl in impls:
            got = impl(s1, s2, s3)
            assert got is expected, (
                f"{impl.__name__}({s1!r}, {s2!r}, {s3!r}) = {got}, "
                f"expected {expected}"
            )
    print("edge_cases: PASS")


def stress_test(trials: int = 300, seed: int = 42) -> None:
    """Cross-check all four implementations on random small inputs.

    About half of the trials use a genuine interleaving of ``s1`` and ``s2``.
    Those must all be recognised (True). The rest use a random ``s3`` over the
    same alphabet. That ``s3`` usually has the correct total length, so the
    DP tables are actually exercised instead of the early length check.
    Roughly one in five of those is one character too long or too short, to
    keep the length-mismatch path covered.

    Args:
        trials: number of random cases to run (default 300).
        seed: seed for a private ``random.Random`` instance, so runs are
            reproducible.

    Raises:
        AssertionError: on the first trial where the implementations disagree,
            or where a genuine interleaving is rejected.
    """
    rng = random.Random(seed)
    alphabet = "ab"

    def rand_str(length: int) -> str:
        return "".join(rng.choice(alphabet) for _ in range(length))

    for trial in range(trials):
        s1 = rand_str(rng.randint(0, 4))
        s2 = rand_str(rng.randint(0, 4))
        genuine = rng.random() < 0.5
        if genuine:
            # Build a real interleaving by randomly choosing which string to
            # draw the next character from until both are exhausted.
            i = j = 0
            s3_chars = []
            while i < len(s1) or j < len(s2):
                if i < len(s1) and (j >= len(s2) or rng.random() < 0.5):
                    s3_chars.append(s1[i])
                    i += 1
                else:
                    s3_chars.append(s2[j])
                    j += 1
            s3 = "".join(s3_chars)
        else:
            # Mostly length-matched noise; occasionally off by one.
            target_len = len(s1) + len(s2)
            if rng.random() < 0.2:
                target_len = max(0, target_len + rng.choice((-1, 1)))
            s3 = rand_str(target_len)
        a = is_interleave_dp(s1, s2, s3)
        b = is_interleave_1d(s1, s2, s3)
        c = is_interleave_memo(s1, s2, s3)
        d = is_interleave_brute(s1, s2, s3)
        assert a == b == c == d, (
            f"trial {trial}: s1={s1!r} s2={s2!r} s3={s3!r} "
            f"dp={a} 1d={b} memo={c} brute={d}"
        )
        assert a or not genuine, (
            f"trial {trial}: genuine interleaving rejected: "
            f"s1={s1!r} s2={s2!r} s3={s3!r}"
        )
    print(
        f"stress_test: {trials} trials — iterative DP, 1D rolling, memoized, "
        "brute all agree."
    )


if __name__ == "__main__":
    # Deterministic hand-written cases first, then the randomized cross-check;
    # both raise AssertionError on the first failure.
    for _suite in (edge_cases, stress_test):
        _suite()
    print("ALL TESTS PASSED")
