"""
p74 — Distinct Subsequences (LeetCode 115, HARD).

dp[i][j] = number of distinct subsequences of s[:i] equal to t[:j].
Recurrence:
  if s[i-1] != t[j-1]: dp[i][j] = dp[i-1][j]
  else:                dp[i][j] = dp[i-1][j-1] + dp[i-1][j]
Base: dp[i][0] = 1 (one way: empty subsequence); dp[0][j>=1] = 0.
"""
from __future__ import annotations
import random
from functools import lru_cache
from itertools import combinations


def num_distinct_dp(s: str, t: str) -> int:
    """Count the subsequences of ``s`` equal to ``t`` with a full 2-D table.

    ``table[i][j]`` holds the number of ways to pick a subsequence of
    ``s[:i]`` that spells ``t[:j]``.

    Args:
        s: The source string subsequences are drawn from.
        t: The target string to match.

    Returns:
        The number of index-distinct subsequences of ``s`` equal to ``t``.
    """
    rows, cols = len(s), len(t)
    if cols > rows:
        # A target longer than the source can never be matched.
        return 0

    # Column 0 is 1 on every row (the empty target is matched exactly once);
    # every other cell starts at 0, which also covers row 0.
    table = [[1] + [0] * cols for _ in range(rows + 1)]

    for i, s_char in enumerate(s, start=1):
        above = table[i - 1]
        current = table[i]
        for j, t_char in enumerate(t, start=1):
            # Always allowed: leave s[i-1] unused. On a character match we
            # may additionally consume it for t[j-1].
            matched = above[j - 1] if s_char == t_char else 0
            current[j] = above[j] + matched

    return table[rows][cols]


def num_distinct_1d(s: str, t: str) -> int:
    """Count the subsequences of ``s`` equal to ``t`` using a single row.

    ``ways[k]`` is the number of ways the characters of ``s`` processed so
    far can spell ``t[:k]``.

    Args:
        s: The source string subsequences are drawn from.
        t: The target string to match.

    Returns:
        The number of index-distinct subsequences of ``s`` equal to ``t``.
    """
    target_len = len(t)
    if target_len > len(s):
        return 0

    # ways[0] == 1: the empty target is always matched exactly once.
    ways = [1] + [0] * target_len

    for s_char in s:
        # Walk the target from its last character to its first so that
        # ways[k] is still the value from before this s_char was processed
        # when it is added into ways[k + 1].
        for k in reversed(range(target_len)):
            if t[k] == s_char:
                ways[k + 1] += ways[k]

    return ways[target_len]


def num_distinct_memo(s: str, t: str) -> int:
    """Count the subsequences of ``s`` equal to ``t`` with top-down memoisation.

    State ``(i, j)`` is the number of ways ``s[i:]`` can spell ``t[j:]``.
    States are evaluated lazily from ``(0, 0)`` using an explicit stack
    instead of Python recursion, so the depth of the dependency chain (one
    level per character of ``s``) cannot raise ``RecursionError`` on long
    inputs.

    Args:
        s: The source string subsequences are drawn from.
        t: The target string to match.

    Returns:
        The number of index-distinct subsequences of ``s`` equal to ``t``.
    """
    m, n = len(s), len(t)
    memo: dict[tuple[int, int], int] = {}
    stack = [(0, 0)]

    while stack:
        state = stack[-1]
        if state in memo:
            # Reached through another parent and already resolved.
            stack.pop()
            continue

        i, j = state
        if j == n:
            # Whole target matched: exactly one way (take nothing more).
            memo[state] = 1
            stack.pop()
            continue
        if m - i < n - j:
            # Not enough source characters left to finish the target; this
            # also covers i == m with part of the target still unmatched.
            memo[state] = 0
            stack.pop()
            continue

        # Here i < m and j < n. Skipping s[i] is always possible; using it
        # for t[j] is possible only when the characters agree.
        children = [(i + 1, j)]
        if s[i] == t[j]:
            children.append((i + 1, j + 1))

        unresolved = [child for child in children if child not in memo]
        if unresolved:
            # Resolve the dependencies first; this state stays on the stack
            # and is revisited once they are all memoised.
            stack.extend(unresolved)
        else:
            memo[state] = sum(memo[child] for child in children)
            stack.pop()

    return memo[(0, 0)]


def num_distinct_brute(s: str, t: str) -> int:
    """Count the subsequences of ``s`` equal to ``t`` by exhaustive enumeration.

    Every way of choosing ``len(t)`` positions of ``s`` (kept in their
    original order) is generated and compared with ``t``. Exponential in
    general; intended only as a reference oracle for short inputs.

    Args:
        s: The source string subsequences are drawn from.
        t: The target string to match.

    Returns:
        The number of index-distinct subsequences of ``s`` equal to ``t``.
    """
    wanted = tuple(t)
    # combinations() picks by position, so equal characters at different
    # indices are counted as separate selections.
    return sum(1 for picked in combinations(s, len(wanted)) if picked == wanted)


def edge_cases() -> None:
    """Check ``num_distinct_dp`` against hand-computed answers.

    Covers the two classic examples plus boundary inputs: empty source,
    empty target, identical strings, and a target longer than the source.
    Prints a confirmation line when every check passes.

    Raises:
        AssertionError: If any expected count does not match.
    """
    # (source, target, expected count)
    expectations = (
        ("rabbbit", "rabbit", 3),
        ("babgbag", "bag", 5),
        ("", "", 1),
        ("abc", "", 1),
        ("", "abc", 0),
        ("aa", "a", 2),
        ("abc", "abc", 1),
        ("abc", "abcd", 0),
    )
    for source, target, expected in expectations:
        assert num_distinct_dp(source, target) == expected
    print("edge_cases: PASS")


def stress_test() -> None:
    """Cross-check all four implementations on 300 seeded random inputs.

    Sources are 0-7 characters and targets 0-4 characters over the alphabet
    ``"ab"``, drawn from a generator seeded with 42 so runs are repeatable.
    Prints a summary line when every trial agrees.

    Raises:
        AssertionError: If the implementations disagree on any trial; the
            message carries the inputs and each implementation's answer.
    """
    rng = random.Random(42)
    alphabet = "ab"

    def random_word(max_len: int) -> str:
        # Draw the length first, then that many characters.
        length = rng.randint(0, max_len)
        return "".join(rng.choice(alphabet) for _ in range(length))

    implementations = (
        num_distinct_dp,
        num_distinct_1d,
        num_distinct_memo,
        num_distinct_brute,
    )
    for trial in range(300):
        s = random_word(7)
        t = random_word(4)
        a, b, c, d = (impl(s, t) for impl in implementations)
        assert a == b == c == d, f"trial {trial}: s={s!r} t={t!r} dp={a} 1d={b} memo={c} brute={d}"
    print("stress_test: 300 trials — iterative DP, 1D rolling, memoized, combinations brute all agree.")


if __name__ == "__main__":
    # Run the fixed edge cases first, then the randomized cross-check; either
    # one raises AssertionError on failure, so reaching the final print means
    # everything passed.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
