"""
p38 — Subsets (LeetCode 78, MEDIUM).

Six functions — four LC 78 variants, one LC 90 variant, and a brute-force oracle:
    subsets_backtrack       — start-index, snapshot at every node.
    subsets_include_exclude — binary choice tree (depth = n).
    subsets_iterative       — powerset growth, no recursion.
    subsets_bitmask         — enumerate 0..2^n - 1.
    subsets_with_dup        — LC 90 (duplicates), sort + skip rule.
    subsets_brute           — oracle via itertools.

Run: python3 solution.py
"""
from __future__ import annotations
import itertools
import random
from typing import List


# --------------------------------------------------------------------------------------
# Variant 1: backtracking, start index, snapshot at every node.
# --------------------------------------------------------------------------------------
def subsets_backtrack(nums: List[int]) -> List[List[int]]:
    """Return every subset of ``nums`` via start-index search, recording every node.

    Subsets are produced in pre-order of the start-index search tree: a node is
    recorded before any of its extensions, and extensions by a smaller index are
    explored first. For ``[1, 2, 3]`` the order is
    ``[], [1], [1, 2], [1, 2, 3], [1, 3], [2], [2, 3], [3]``.
    An explicit stack replaces recursion; every emitted subset is a fresh list.
    """
    emitted: List[List[int]] = []
    # Each frame is (subset chosen so far, lowest index allowed to extend it).
    frames: List[tuple] = [([], 0)]
    while frames:
        prefix, lo = frames.pop()
        # Every node of the tree is itself a valid subset — record it unconditionally.
        emitted.append(prefix)
        # Push children in descending index order so the smallest index is popped
        # first, reproducing the recursive pre-order.
        for idx in range(len(nums) - 1, lo - 1, -1):
            frames.append((prefix + [nums[idx]], idx + 1))
    return emitted


# --------------------------------------------------------------------------------------
# Variant 2: include/exclude — binary choice tree (depth n, branching 2).
# --------------------------------------------------------------------------------------
def subsets_include_exclude(nums: List[int]) -> List[List[int]]:
    """Return every subset of ``nums`` by deciding skip/take for each position in turn.

    The subsets of ``nums[pos:]`` are built from those of ``nums[pos + 1:]``: first
    all of them without ``nums[pos]`` (the "exclude" branch), then each of them
    prefixed with ``nums[pos]`` (the "include" branch). For ``[1, 2, 3]`` the order
    is ``[], [3], [2], [2, 3], [1], [1, 3], [1, 2], [1, 2, 3]``.
    """
    size = len(nums)

    def decide(pos: int) -> List[List[int]]:
        # Past the last position: the only choice left is the empty subset.
        if pos >= size:
            return [[]]
        rest = decide(pos + 1)
        taken = [[nums[pos]] + tail for tail in rest]
        return rest + taken

    return decide(0)


# --------------------------------------------------------------------------------------
# Variant 3: iterative powerset growth. No recursion.
# --------------------------------------------------------------------------------------
def subsets_iterative(nums: List[int]) -> List[List[int]]:
    """Return every subset of ``nums`` by doubling the powerset once per element.

    After processing ``nums[:k]`` the list holds all 2^k subsets of that prefix;
    appending ``value`` to a copy of each yields the new subsets that contain it.
    """
    powerset: List[List[int]] = [[]]
    for value in nums:
        # Materialize the new half first (a list, not a generator): a lazy view over
        # `powerset` would keep seeing its own additions while extending.
        with_value = [existing + [value] for existing in powerset]
        powerset.extend(with_value)
    return powerset


# --------------------------------------------------------------------------------------
# Variant 4: bitmask. Python ints are unbounded, so n has no word-size cap; 2^n output is.
# --------------------------------------------------------------------------------------
def subsets_bitmask(nums: List[int]) -> List[List[int]]:
    """Return every subset of ``nums`` by decoding each mask in ``0 .. 2**n - 1``.

    Bit ``i`` of the mask selects ``nums[i]``. Masks are visited in increasing
    order, and elements keep their original relative order within each subset.
    """
    total_masks = 2 ** len(nums)
    return [
        [value for bit, value in enumerate(nums) if mask & (1 << bit)]
        for mask in range(total_masks)
    ]


# --------------------------------------------------------------------------------------
# LC 90: Subsets with duplicates. Sort + dedup at same level.
# --------------------------------------------------------------------------------------
def subsets_with_dup(nums: List[int]) -> List[List[int]]:
    """Return each distinct subset of the multiset ``nums`` exactly once (LC 90).

    The input is sorted (the caller's list is not modified) so equal values are
    adjacent. Within one level of the start-index search, only the first of a run
    of equal values is used to extend the current subset; later equal siblings
    would reproduce subsets already generated. Each emitted subset is in
    non-decreasing order.
    """
    ordered = sorted(nums)
    count = len(ordered)
    found: List[List[int]] = []
    trail: List[int] = []

    def extend_from(lo: int) -> None:
        # Every search node is a distinct subset — record it before extending.
        found.append(trail.copy())
        for idx in range(lo, count):
            is_repeat = idx != lo and ordered[idx] == ordered[idx - 1]
            if not is_repeat:
                trail.append(ordered[idx])
                extend_from(idx + 1)
                trail.pop()

    extend_from(0)
    return found


# --------------------------------------------------------------------------------------
# Oracle.
# --------------------------------------------------------------------------------------
def subsets_brute(nums: List[int]) -> List[List[int]]:
    """Oracle: every subset of ``nums`` via ``itertools.combinations``, smallest size first.

    Positions are treated as distinct, so duplicate values yield repeated subsets.
    """
    return [
        list(combo)
        for size in range(len(nums) + 1)
        for combo in itertools.combinations(nums, size)
    ]


def _canon(subs: List[List[int]]) -> List[tuple]:
    return sorted(tuple(sorted(s)) for s in subs)


def _canon_unique(subs: List[List[int]]) -> List[tuple]:
    return sorted(set(tuple(sorted(s)) for s in subs))


# --------------------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------------------
def edge_cases() -> None:
    """Hand-picked checks for every variant, including LC 90 on empty and all-equal input."""
    lc78_variants = (subsets_backtrack, subsets_include_exclude, subsets_iterative, subsets_bitmask)

    # Empty: exactly one subset, the empty one (exact equality, not canonicalized).
    # subsets_with_dup is included here — it was previously never checked on [].
    for fn in lc78_variants + (subsets_with_dup,):
        assert fn([]) == [[]], f"{fn.__name__} on []"

    # Single
    for fn in lc78_variants:
        assert _canon(fn([7])) == _canon([[], [7]])

    # Three distinct: should produce 8 subsets each, matching the oracle
    oracle_123 = _canon(subsets_brute([1, 2, 3]))
    for fn in lc78_variants:
        got_123 = fn([1, 2, 3])
        assert len(got_123) == 8
        assert _canon(got_123) == oracle_123

    # n = 4 → 16 subsets
    for fn in lc78_variants:
        assert len(fn([1, 2, 3, 4])) == 16

    # LC 90: [1,2,2] should have unique subsets [[],[1],[2],[1,2],[2,2],[1,2,2]] = 6
    expected_122 = [[], [1], [2], [1, 2], [2, 2], [1, 2, 2]]
    got_122 = subsets_with_dup([1, 2, 2])
    assert len(got_122) == 6
    assert _canon(got_122) == _canon(expected_122)

    # LC 90: all same — [3,3,3] → [[],[3],[3,3],[3,3,3]] = 4
    assert _canon(subsets_with_dup([3, 3, 3])) == _canon([[], [3], [3, 3], [3, 3, 3]])

    # LC 90: distinct collapses to LC 78
    assert _canon(subsets_with_dup([1, 2, 3])) == _canon(subsets_backtrack([1, 2, 3]))

    # No internal duplicates in LC 90 output, and the count matches the multiset
    # formula ∏(multiplicity + 1) = (1+1)·(2+1)·(2+1) = 18 for {1, 2×2, 3×2}.
    got = subsets_with_dup([1, 2, 2, 3, 3])
    assert len(got) == len(set(tuple(sorted(s)) for s in got))
    assert len(got) == 18

    print("edge_cases: PASS")


def stress_test() -> None:
    """Randomized (seeded, reproducible) comparison of every variant against the oracle."""
    rng = random.Random(42)
    # Distinct: all four variants vs oracle
    for _ in range(100):
        n = rng.randint(0, 6)
        nums = rng.sample(range(-15, 15), n)
        oracle = _canon(subsets_brute(nums))
        for fn in (subsets_backtrack, subsets_include_exclude, subsets_iterative, subsets_bitmask):
            assert _canon(fn(nums)) == oracle, f"{fn.__name__} disagrees on {nums}"

    # Duplicates: subsets_with_dup vs deduped oracle.
    # Canonicalize the candidate WITHOUT deduplicating it. LC 90 requires that no subset
    # is emitted twice, and applying `_canon_unique` to the output would hide exactly
    # that bug (e.g. a missing same-level skip rule).
    for _ in range(100):
        n = rng.randint(0, 6)
        nums = [rng.randint(0, 3) for _ in range(n)]
        oracle = _canon_unique(subsets_brute(nums))
        assert _canon(subsets_with_dup(nums)) == oracle, f"subsets_with_dup disagrees on {nums}"

    print("stress_test: 200 random inputs (100 distinct, 100 duplicates) — all variants agree with oracle.")


if __name__ == "__main__":
    # Deterministic checks first, then the seeded randomized comparison.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
