"""
p61 — Partition Equal Subset Sum (LeetCode 416, MEDIUM).

Reduce to: subset-sum == total/2. Classic 0/1 knapsack.
"""
from __future__ import annotations
import random
from typing import List


def can_partition_dp(nums: List[int]) -> bool:
    """Decide whether ``nums`` splits into two subsets with equal sums.

    Uses a one-dimensional 0/1 knapsack table: ``reachable[s]`` is True once
    some subset of the items seen so far adds up to ``s``.

    Args:
        nums: The values to partition.

    Returns:
        True if some subset sums to exactly half of ``sum(nums)``; False if
        the total is odd or no such subset exists.
    """
    half, remainder = divmod(sum(nums), 2)
    if remainder:
        # An odd total cannot be divided into two equal integer halves.
        return False

    reachable = [False] * (half + 1)
    reachable[0] = True  # the empty subset sums to 0

    for value in nums:
        # Visit sums from high to low so reachable[s - value] still describes
        # the state before `value` was considered; each item is used at most once.
        for s in reversed(range(value, half + 1)):
            if reachable[s - value]:
                reachable[s] = True

    return reachable[half]


def can_partition_bitmask(nums: List[int]) -> bool:
    """Decide whether ``nums`` splits into two equal-sum subsets, via a bitset.

    An arbitrary-precision integer is used as a bitset: bit ``s`` is set when
    some subset of the items processed so far sums to ``s``. Adding an item
    ``x`` shifts every reachable sum up by ``x`` and ORs the result back in.

    Args:
        nums: Non-negative integers to partition. Negative values are not
            supported (a negative shift count raises ``ValueError``).

    Returns:
        True if some subset sums to exactly half of ``sum(nums)``; False if
        the total is odd or no such subset exists.
    """
    total = sum(nums)
    if total % 2:
        return False
    target = total // 2

    # Only bits 0..target can influence the answer: with non-negative items a
    # sum that has exceeded target never comes back down. Masking keeps the
    # integer at target + 1 bits instead of letting it grow to sum(nums) bits.
    keep = (1 << (target + 1)) - 1

    reachable = 1  # bit 0 set: the empty subset sums to 0
    for x in nums:
        reachable |= (reachable << x) & keep
    return bool((reachable >> target) & 1)


def can_partition_brute(nums: List[int]) -> bool:
    """Exhaustively check whether ``nums`` splits into two equal-sum subsets.

    Every one of the ``2 ** len(nums)`` subsets is encoded as a bitmask and
    its sum compared with half the total. Intended as a slow reference
    implementation for small inputs.

    Args:
        nums: The values to partition.

    Returns:
        True if some subset sums to exactly half of ``sum(nums)``; False if
        the total is odd or no such subset exists.
    """
    half, remainder = divmod(sum(nums), 2)
    if remainder:
        return False

    def subset_sum(mask: int) -> int:
        # Bit i of `mask` (counting from the least significant bit) selects nums[i].
        acc = 0
        for value in nums:
            if mask & 1:
                acc += value
            mask >>= 1
        return acc

    subset_count = 1 << len(nums)
    return any(subset_sum(mask) == half for mask in range(subset_count))


def edge_cases() -> None:
    """Run hand-picked sanity checks against the DP and bitmask solvers.

    Raises AssertionError at the first case whose result differs from the
    expected value; prints ``edge_cases: PASS`` when every case holds.
    """
    dp_cases = (
        ([1, 5, 11, 5], True),     # {11} vs {1, 5, 5}
        ([1, 2, 3, 5], False),     # odd total (11)
        ([1, 1], True),
        ([1], False),              # odd total (1)
        ([2, 2, 2, 2], True),
        ([100, 100], True),
        ([3, 3, 3, 4, 5], True),   # total 18 -> {4, 5} vs {3, 3, 3}, 9 each
    )
    for nums, expected in dp_cases:
        assert can_partition_dp(nums) is expected

    bitmask_cases = (
        ([1, 5, 11, 5], True),
        ([1, 2, 3, 5], False),
    )
    for nums, expected in bitmask_cases:
        assert can_partition_bitmask(nums) is expected

    print("edge_cases: PASS")


def stress_test() -> None:
    """Cross-check the three solvers on 500 seeded random inputs.

    Each trial draws a list of 1 to 10 integers, each in the range 1..8, from
    a ``random.Random(42)`` generator, so the run is reproducible. Raises
    AssertionError (with the offending input and all three answers) on the
    first disagreement; prints a summary line otherwise.
    """
    generator = random.Random(42)
    solvers = (can_partition_dp, can_partition_bitmask, can_partition_brute)

    for trial in range(500):
        # The length is drawn first, then the values, in that order.
        nums = [generator.randint(1, 8) for _ in range(generator.randint(1, 10))]
        dp_ans, bit_ans, brute_ans = [solve(nums) for solve in solvers]
        all_agree = dp_ans == bit_ans and bit_ans == brute_ans
        assert all_agree, f"trial {trial}: nums={nums} dp={dp_ans} bitmask={bit_ans} brute={brute_ans}"

    print("stress_test: 500 trials — dp, bitmask, brute agree.")


# Script entry point: run the fixed edge cases first, then the randomized
# cross-check. The final line is printed only if neither raised.
if __name__ == "__main__":
    edge_cases()
    stress_test()
    print("ALL TESTS PASSED")
