"""
p75 — Palindrome Partitioning II (LeetCode 132, HARD).

Two-phase O(n^2):
  Phase 1: is_pal[i][j] table by increasing length.
  Phase 2: cuts[i] = min cuts for s[:i+1].
           If is_pal[0][i]: cuts[i] = 0.
           Else: cuts[i] = min over j in 1..i where is_pal[j][i] of (cuts[j-1] + 1).
"""
from __future__ import annotations
import random


def min_cut_dp(s: str) -> int:
    """Return the fewest cuts needed to split *s* into palindromic pieces.

    Two quadratic passes: first tabulate which substrings are palindromes,
    then run a prefix DP over that table. Strings of length 0 or 1 need
    no cut.
    """
    size = len(s)
    if size < 2:
        return 0

    # Pass 1: pal[a][b] is True iff s[a..b] (inclusive) is a palindrome.
    # Walking the start index downwards guarantees the inner span
    # pal[a + 1][b - 1] is already settled when pal[a][b] is evaluated.
    pal = [[False] * size for _ in range(size)]
    for a in range(size - 1, -1, -1):
        pal[a][a] = True
        for b in range(a + 1, size):
            if s[a] != s[b]:
                continue
            if b - a == 1 or pal[a + 1][b - 1]:
                pal[a][b] = True

    # Pass 2: best[b] is the minimum number of cuts for the prefix s[:b + 1].
    best = [0] * size
    for b in range(size):
        if pal[0][b]:
            # The whole prefix is a palindrome: zero cuts (already stored).
            continue
        # Otherwise the last piece is some palindrome s[a..b] with a >= 1,
        # preceded by an optimal partition of s[:a]. The value b itself is
        # the fallback of cutting between every pair of characters.
        options = [best[a - 1] + 1 for a in range(1, b + 1) if pal[a][b]]
        best[b] = min([b, *options])
    return best[-1]


def min_cut_expand(s: str) -> int:
    """Return the fewest palindrome-partition cuts for *s* by centre expansion.

    Grows palindromes outwards from every odd and even centre and relaxes
    a one-dimensional table as each palindrome is found, so no
    two-dimensional palindrome table is kept.
    """
    size = len(s)
    if size < 2:
        return 0

    # best[k] = minimum cuts for the first k characters of s. It starts at
    # the worst case k - 1; best[0] == -1 is a sentinel so that a palindrome
    # starting at index 0 yields -1 + 1 == 0 cuts.
    best = [k - 1 for k in range(size + 1)]

    for centre in range(size):
        # (centre, centre) seeds odd-length palindromes,
        # (centre, centre + 1) seeds even-length ones.
        for lo, hi in ((centre, centre), (centre, centre + 1)):
            while lo >= 0 and hi < size and s[lo] == s[hi]:
                # s[lo..hi] is a palindrome: it can be the last piece of
                # the prefix ending at hi, after the best split of s[:lo].
                candidate = best[lo] + 1
                if candidate < best[hi + 1]:
                    best[hi + 1] = candidate
                lo -= 1
                hi += 1
    return best[size]


def min_cut_brute(s: str) -> int:
    """Return the fewest palindrome-partition cuts for *s* by exhaustive search.

    Tries every one of the 2^(n-1) ways to place cuts in the gaps between
    adjacent characters. Exponential, so only useful as a reference oracle
    on short strings.
    """
    size = len(s)
    if size < 2:
        return 0

    gaps = size - 1
    fewest = gaps  # cutting in every gap always works (single characters)
    for mask in range(1 << gaps):
        # Bit g set in mask means "cut right after character g".
        bounds = [0]
        bounds.extend(g + 1 for g in range(gaps) if (mask >> g) & 1)
        bounds.append(size)
        pieces = [s[lo:hi] for lo, hi in zip(bounds, bounds[1:])]
        if all(piece == piece[::-1] for piece in pieces):
            fewest = min(fewest, len(pieces) - 1)
    return fewest


def edge_cases() -> None:
    """Check min_cut_dp against a handful of hand-computed answers."""
    expected_cuts = {
        "aab": 1,
        "a": 0,
        "ab": 1,
        "aabbaa": 0,  # already a palindrome
        "abcd": 3,
        "aaaa": 0,
        "": 0,
    }
    for text, cuts in expected_cuts.items():
        assert min_cut_dp(text) == cuts
    print("edge_cases: PASS")


def stress_test() -> None:
    """Cross-check the three implementations on 300 seeded random strings.

    Strings have length 0 to 8 over the letters "abc"; the generator is
    seeded with 42, so every run sees the same inputs.
    """
    letters = "abc"
    gen = random.Random(42)
    for trial in range(300):
        length = gen.randint(0, 8)
        s = "".join([gen.choice(letters) for _ in range(length)])
        # Evaluated left to right: DP, then expansion, then brute force.
        a, b, c = min_cut_dp(s), min_cut_expand(s), min_cut_brute(s)
        assert a == b == c, f"trial {trial}: s={s!r} dp={a} expand={b} brute={c}"
    print("stress_test: 300 trials — two-phase DP, expand-around-center, brute partition all agree.")


if __name__ == "__main__":
    # Run the fixed cases first, then the randomized cross-check; either
    # one aborts with AssertionError on the first mismatch.
    for run_check in (edge_cases, stress_test):
        run_check()
    print("ALL TESTS PASSED")
