"""
p77 — Strange Printer (LeetCode 664, HARD).

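Interval DP over the run-collapsed string: each run of identical adjacent
characters is merged into one character first. This never changes the answer,
because a stroke that prints one character of a run can print the whole run.
dp[i][j] = minimum printer turns for s[i..j].
Base: dp[i][i] = 1; an empty range (i > j) costs 0.
Recurrence:
    dp[i][j] = dp[i+1][j] + 1  (print s[i] alone, then s[i+1..j])
    for k in i+1..j: if s[k] == s[i]:
        dp[i][j] = min(dp[i][j], dp[i+1][k-1] + dp[k][j])
(Savings: the stroke that paints s[i] is extended to also cover s[k], so s[k]
costs no extra turn; s[i+1..k-1] is then printed on top of that stroke.)

Time O(n^3), space O(n^2), where n is the length after collapsing.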
"""
from __future__ import annotations
import random
from functools import lru_cache


def _collapse(s: str) -> str:
    if not s:
        return s
    out = [s[0]]
    for c in s[1:]:
        if c != out[-1]:
            out.append(c)
    return "".join(out)


def strange_printer_dp(s: str) -> int:
    """Return the minimum number of printer turns needed to print *s*.

    Bottom-up interval DP on the run-collapsed string. Rows are filled from
    the last start index back to the first, so every table[a][b] with a > start
    is already final when row ``start`` is computed.
    """
    s = _collapse(s)
    n = len(s)
    if n == 0:
        return 0
    table = [[0] * n for _ in range(n)]
    for start in range(n - 1, -1, -1):
        table[start][start] = 1
        for end in range(start + 1, n):
            # Option 1: s[start] gets its own dedicated turn.
            best = table[start + 1][end] + 1
            # Option 2: the stroke for s[start] is stretched to also cover a
            # later matching s[k]; s[start+1..k-1] is printed on top of it.
            for k in range(start + 1, end + 1):
                if s[k] != s[start]:
                    continue
                inner = table[start + 1][k - 1] if k > start + 1 else 0
                best = min(best, inner + table[k][end])
            table[start][end] = best
    return table[0][n - 1]


def strange_printer_memo(s: str) -> int:
    """Top-down memoised counterpart of :func:`strange_printer_dp`.

    Evaluates the same left-anchored recurrence lazily (cached with
    ``lru_cache``) on the run-collapsed string.
    """
    s = _collapse(s)
    n = len(s)
    if n == 0:
        return 0

    @lru_cache(maxsize=None)
    def turns(lo: int, hi: int) -> int:
        # Empty range costs nothing; a single character costs one turn.
        if lo > hi:
            return 0
        if lo == hi:
            return 1
        head = s[lo]
        # Either print s[lo] on its own, or extend its stroke to a matching s[k].
        options = [turns(lo + 1, hi) + 1]
        options.extend(
            turns(lo + 1, k - 1) + turns(k, hi)
            for k in range(lo + 1, hi + 1)
            if s[k] == head
        )
        return min(options)

    return turns(0, n - 1)


def strange_printer_brute(s: str) -> int:
    """Reference oracle: right-anchored interval DP on the raw (uncollapsed) string.

    Mirror image of the main recurrence, anchored on the LAST character:
        cost(i, j) = cost(i, j-1) + 1                  # s[j] gets its own turn
        cost(i, j) = min over k in [i, j-1] with s[k] == s[j] of
                     cost(i, k) + cost(k+1, j-1)       # s[j] rides s[k]'s stroke
    with cost = 0 for an empty range and 1 for a single character.
    It is expected to agree with the left-anchored solvers on every input.
    """
    length = len(s)
    if length == 0:
        return 0

    @lru_cache(maxsize=None)
    def cost(lo: int, hi: int) -> int:
        if lo > hi:
            return 0
        if lo == hi:
            return 1
        last = s[hi]
        candidates = [cost(lo, hi - 1) + 1]
        # cost(k + 1, hi - 1) is 0 when that range is empty (k == hi - 1).
        candidates.extend(
            cost(lo, k) + cost(k + 1, hi - 1)
            for k in range(lo, hi)
            if s[k] == last
        )
        return min(candidates)

    return cost(0, length - 1)


def edge_cases() -> None:
    """Check every solver against hand-verified answers.

    The original version only exercised ``strange_printer_dp``. Here the
    memoised solver and the right-anchored oracle are held to the same known
    answers. The oracle receives the raw string, so uncollapsed input is
    covered as well.

    Raises:
        AssertionError: naming the first solver/input pair that disagrees.
    """
    cases = [
        ("aaabbb", 2),
        ("aba", 2),
        ("a", 1),
        ("", 0),
        ("aabbaa", 2),
        ("abcabc", 5),
        # Needs a nested stroke: t(0..3), b(1..5), g(2..4), t(3).
        ("tbgtgb", 4),
    ]
    solvers = (strange_printer_dp, strange_printer_memo, strange_printer_brute)
    for s, expected in cases:
        for solve in solvers:
            got = solve(s)
            assert got == expected, (
                f"{solve.__name__}({s!r}) = {got}, expected {expected}"
            )
    print("edge_cases: PASS")


def stress_test(
    trials: int = 300,
    max_len: int = 7,
    alphabet: str = "abc",
    seed: int = 42,
) -> None:
    """Cross-check all three solvers on random strings.

    The oracle is fed the RAW string rather than the run-collapsed one. This
    means the run-collapsing shortcut used by the dp and memo solvers is
    validated too, not assumed.

    Args:
        trials: number of random strings to test.
        max_len: maximum string length; lengths are drawn uniformly from 0..max_len.
        alphabet: characters the strings are drawn from.
        seed: RNG seed, so runs are reproducible.

    Raises:
        AssertionError: on the first input where the solvers disagree.
    """
    rng = random.Random(seed)
    for trial in range(trials):
        n = rng.randint(0, max_len)
        s = "".join(rng.choice(alphabet) for _ in range(n))
        a = strange_printer_dp(s)
        b = strange_printer_memo(s)
        c = strange_printer_brute(s)  # raw input: independent of _collapse
        assert a == b == c, f"trial {trial}: s={s!r} dp={a} memo={b} brute={c}"
    print(
        f"stress_test: {trials} trials — iterative interval DP, memoized, "
        "right-anchored oracle all agree."
    )


if __name__ == "__main__":
    # Run the deterministic checks first, then the randomised cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
