"""
p85 — Split Array Largest Sum (LC 410)

INVARIANT (monotonicity): feasible(S) = "nums can be partitioned into <= k
  contiguous parts, each with sum <= S" is monotone in S: if S is feasible,
  so is every S' >= S. Binary-search the smallest feasible S.

INVARIANT (greedy feasibility): for non-negative nums, first-fit left-to-right
  (extend the current part until adding the next element would exceed S) is
  optimal — closing a part earlier than necessary never helps the suffix.
"""
from __future__ import annotations

import random
from itertools import combinations
from typing import List


def split_array_bsa(nums: List[int], k: int) -> int:
    """Minimize the largest part sum via binary search on the answer.

    Runs in O(n log(sum(nums))): each candidate cap is probed with an O(n)
    greedy first-fit scan.

    Args:
        nums: Non-negative integers to split into contiguous parts.
        k: Maximum number of parts allowed (must be >= 1). A ``k`` larger
            than ``len(nums)`` is allowed; the answer is then ``max(nums)``.

    Returns:
        The smallest achievable value of the largest part sum. An empty
        ``nums`` yields 0.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if not nums:
        return 0

    def fits(cap: int) -> bool:
        """Return True if first-fit packs nums into <= k parts of sum <= cap."""
        parts = 1
        cur = 0
        for x in nums:
            if cur + x > cap:
                parts += 1
                if parts > k:
                    return False  # already over budget; no need to finish the scan
                cur = x
            else:
                cur += x
        return True

    # Each element must fit in some part, so the answer is >= max(nums);
    # a single part holding everything gives the upper bound sum(nums).
    lo, hi = max(nums), sum(nums)
    while lo < hi:
        mid = (lo + hi) // 2
        if fits(mid):
            hi = mid  # feasible; try smaller
        else:
            lo = mid + 1
    return lo


def split_array_dp(nums: List[int], k: int) -> int:
    """O(n^2 * k) DP reference solution.

    dp[i][j] = minimum possible largest part sum when splitting nums[:i] into
    exactly j non-empty contiguous parts.

    Args:
        nums: Non-negative integers to split into contiguous parts.
        k: Maximum number of parts allowed (must be >= 1). A ``k`` larger
            than ``len(nums)`` is allowed; the answer is then ``max(nums)``.

    Returns:
        The smallest achievable value of the largest part sum. An empty
        ``nums`` yields 0.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    n = len(nums)
    if n == 0:
        return 0
    # More than n non-empty parts is impossible. With non-negative nums,
    # splitting a part further never raises the maximum, so "at most k parts"
    # is answered by "exactly min(k, n) parts". Without this clamp, dp[n][k]
    # would stay INF and int(INF) would raise OverflowError.
    k = min(k, n)

    prefix = [0] * (n + 1)
    for i, x in enumerate(nums):
        prefix[i + 1] = prefix[i] + x

    INF = float("inf")
    # dp[i][j] for i in 0..n, j in 0..k
    dp = [[INF] * (k + 1) for _ in range(n + 1)]
    dp[0][0] = 0
    for i in range(1, n + 1):
        for j in range(1, min(i, k) + 1):
            # Last part is nums[t:i]; t >= j-1 so nums[:t] can still fill
            # j-1 non-empty parts.
            dp[i][j] = min(
                max(dp[t][j - 1], prefix[i] - prefix[t]) for t in range(j - 1, i)
            )
    return int(dp[n][k])


def split_array_brute(nums: List[int], k: int) -> int:
    """Reference oracle: try every placement of k-1 cuts and keep the best.

    Exhaustively enumerates all C(n-1, k-1) sets of cut positions, so it is
    only suitable for tiny inputs such as those in the stress test.
    """
    size = len(nums)
    if k == 1:
        return sum(nums)
    if k >= size:
        return max(nums) if nums else 0

    answer = float("inf")
    # Cuts are drawn from 1..n-1 so that every part is non-empty.
    for cut_set in combinations(range(1, size), k - 1):
        bounds = (0,) + cut_set + (size,)
        worst = 0
        for start, stop in zip(bounds, bounds[1:]):
            worst = max(worst, sum(nums[start:stop]))
        answer = min(answer, worst)
    return int(answer)


def edge_cases() -> None:
    """Hand-picked regression cases, checked against every implementation.

    Each case is (nums, k, expected). Every solver receives its own copy of
    nums so that a mutating solver cannot corrupt later checks.
    """
    cases = [
        ([7, 2, 5, 10, 8], 2, 18),  # LC 410 example
        ([1, 2, 3, 4], 1, 10),  # k = 1 -> whole sum
        ([1, 2, 3, 4], 4, 4),  # k = n -> max single element
        ([0, 0, 0, 0], 2, 0),  # all zeros
        ([5], 1, 5),  # single element
        ([1, 1, 100, 1, 1], 3, 100),  # one element dominates
    ]
    for solve in (split_array_bsa, split_array_dp, split_array_brute):
        for nums, k, expected in cases:
            got = solve(list(nums), k)
            assert got == expected, (solve.__name__, nums, k, got, expected)


def stress_test() -> None:
    """Cross-check the three solvers on 200 small random inputs.

    The generator is seeded, so every run exercises the same inputs.
    """
    gen = random.Random(42)
    for _trial in range(200):
        size = gen.randint(1, 6)
        arr = [gen.randint(0, 10) for _ in range(size)]
        parts = gen.randint(1, size)

        results = (
            split_array_bsa(list(arr), parts),
            split_array_dp(list(arr), parts),
            split_array_brute(list(arr), parts),
        )
        assert results[0] == results[1] == results[2], (arr, parts) + results


if __name__ == "__main__":
    # Fixed regression cases first, then the randomized cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
