"""
p71 — Burst Balloons (LeetCode 312, HARD).

Pad nums with 1 on each end. dp[i][j] = max coins from bursting all balloons
strictly between i and j (exclusive endpoints).

Key insight: enumerate k = the LAST balloon to burst in (i,j). When k bursts,
its neighbors are exactly nums[i] and nums[j] (everything else gone), giving
the clean recurrence:
    dp[i][j] = max over k in (i,j) of nums[i]*nums[k]*nums[j] + dp[i][k] + dp[k][j]

Splitting on the FIRST balloon to burst fails: once k bursts, its former neighbors
become adjacent, so the left and right subproblems are no longer independent.
"""
from __future__ import annotations
import random
from functools import lru_cache
from itertools import permutations
from typing import List


def max_coins_dp(nums: List[int]) -> int:
    """Return the maximum coins obtainable, via bottom-up interval DP.

    The input is padded with a virtual 1 on each side. table[lo][hi] is the
    best score for bursting every balloon strictly between padded positions
    lo and hi. Intervals are processed from shortest span to longest, so
    both halves of every split are already final when they are read.
    Runs in O(n^3) time and O(n^2) space.
    """
    padded = [1, *nums, 1]
    size = len(padded)
    table = [[0 for _ in range(size)] for _ in range(size)]

    # span = hi - lo; a span of at least 2 is needed to enclose any balloon.
    for span in range(2, size):
        for lo in range(0, size - span):
            hi = lo + span
            # mid is the LAST balloon burst in (lo, hi); its neighbours are then
            # exactly padded[lo] and padded[hi]. Scores never drop below 0.
            table[lo][hi] = max(
                0,
                max(
                    padded[lo] * padded[mid] * padded[hi] + table[lo][mid] + table[mid][hi]
                    for mid in range(lo + 1, hi)
                ),
            )
    return table[0][size - 1]


def max_coins_memo(nums: List[int]) -> int:
    """Return the maximum coins obtainable, via top-down memoised recursion.

    Same recurrence as the bottom-up version: for each open interval of the
    sentinel-padded array, try every balloon as the last one burst.
    """
    vals = [1, *nums, 1]
    memo = {}  # (left, right) -> best score strictly between left and right

    def best_between(left: int, right: int) -> int:
        # Adjacent (or equal) positions enclose no balloon.
        if right - left <= 1:
            return 0
        key = (left, right)
        if key in memo:
            return memo[key]
        best = 0
        for mid in range(left + 1, right):
            # mid bursts last, so its neighbours are vals[left] and vals[right].
            score = vals[left] * vals[mid] * vals[right]
            score += best_between(left, mid) + best_between(mid, right)
            if score > best:
                best = score
        memo[key] = best
        return best

    return best_between(0, len(vals) - 1)


def max_coins_brute(nums: List[int]) -> int:
    """Exhaustive reference solver: try every burst order and keep the best.

    Bursting balloon idx scores left * nums[idx] * right. Here left and right
    are the nearest not-yet-burst neighbours, or 1 when none exists on that
    side. Runs in O(n! * n^2), so it is only usable on tiny inputs; it exists
    to cross-check the DP solvers.

    Args:
        nums: balloon values. Not modified.

    Returns:
        The maximum total over all n! burst orders (0 for an empty list).
    """
    n = len(nums)
    if n == 0:
        return 0
    best = 0
    for order in permutations(range(n)):
        # alive[i] is True while original index i has not been burst yet.
        # nums itself is never mutated, so no per-permutation copy is needed.
        alive = [True] * n
        total = 0
        for idx in order:
            # Nearest surviving neighbour on each side; 1 if we run off the end.
            left = next((nums[li] for li in range(idx - 1, -1, -1) if alive[li]), 1)
            right = next((nums[ri] for ri in range(idx + 1, n) if alive[ri]), 1)
            total += left * nums[idx] * right
            alive[idx] = False
        if total > best:
            best = total
    return best


def edge_cases() -> None:
    """Check hand-computed answers against every solver, not just the DP.

    The cases are:
    - the LeetCode example;
    - two- and one-balloon inputs;
    - the empty input, where only the sentinels remain;
    - all-zero values.

    Raises:
        AssertionError: naming the failing solver, input and both values.
    """
    cases = [
        ([3, 1, 5, 8], 167),
        ([1, 5], 10),
        ([1], 1),
        ([], 0),
        ([0, 0, 0], 0),
    ]
    for solver in (max_coins_dp, max_coins_memo, max_coins_brute):
        for nums, expected in cases:
            got = solver(nums)
            assert got == expected, (
                f"{solver.__name__}({nums}) = {got}, expected {expected}"
            )
    print("edge_cases: PASS")


def stress_test() -> None:
    """Cross-check the three solvers on 200 seeded random inputs.

    Inputs have 0-6 balloons with values 0-9. Seed 42 makes every run
    reproducible, so any reported failing trial can be replayed.
    """
    gen = random.Random(42)
    solvers = (max_coins_dp, max_coins_memo, max_coins_brute)
    for trial in range(200):
        size = gen.randint(0, 6)
        nums = [gen.randint(0, 9) for _ in range(size)]
        # Evaluated in order: DP, then memo, then brute force.
        dp_ans, memo_ans, brute_ans = (solve(nums) for solve in solvers)
        assert dp_ans == memo_ans == brute_ans, (
            f"trial {trial}: nums={nums} dp={dp_ans} memo={memo_ans} brute={brute_ans}"
        )
    print("stress_test: 200 trials — iterative interval DP, memoized DP, brute permutation all agree.")


if __name__ == "__main__":
    # Fixed edge cases first, then the randomized cross-check; either raises on failure.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
