"""
LC 15 — 3Sum.

Sort + converging two-pointer, with skip-dedup at all three positions.

Implementations:
    1. three_sum         — canonical. O(N^2) time, O(N) aux for the sorted copy (output excluded).
    2. three_sum_hashset — fix i, hash for inner two-sum. Worse constants.
    3. three_sum_brute   — O(N^3). Oracle.
    4. k_sum             — generalized k-Sum (bonus, includes 4Sum etc).

INVARIANT (outer): after `if i > 0 and nums[i] == nums[i-1]: continue`,
                   each distinct value is used as the anchor nums[i] at most once.
INVARIANT (inner): after recording a triplet and advancing lo/hi, the skip loops
                   ensure (while lo < hi) that nums[lo] differs from the left value
                   just recorded and nums[hi] differs from the right value just recorded.
INVARIANT (sort): every implementation except the brute-force oracle works on
                  `sorted(nums)`, an ascending copy; the caller's list is never mutated.
"""

from __future__ import annotations

import random


# ----------------------------------------------------------------------
# Solution A — Canonical sort + 2-pointer.
# ----------------------------------------------------------------------

def three_sum(nums: list[int]) -> list[list[int]]:
    """Return every distinct triplet [a, b, c] from nums with a + b + c == 0.

    A private sorted copy is searched, so the caller's list is untouched.
    For each distinct anchor value, two pointers sweep inward over the suffix
    to the anchor's right looking for a pair that cancels the anchor.
    Each triplet is in nondecreasing order; triplets are emitted in
    increasing anchor order, then increasing middle value.

    Time O(N^2); extra space O(N) for the sorted copy (output excluded).
    """
    arr = sorted(nums)  # copy: never mutate the caller's list
    size = len(arr)
    triplets: list[list[int]] = []
    for anchor in range(size - 2):
        first = arr[anchor]
        if anchor and first == arr[anchor - 1]:
            continue  # this anchor value was already fully explored
        if first > 0:
            # Sorted ascending: every remaining value is positive too.
            break
        need = -first
        left, right = anchor + 1, size - 1
        while left < right:
            pair = arr[left] + arr[right]
            if pair < need:
                left += 1
            elif pair > need:
                right -= 1
            else:
                triplets.append([first, arr[left], arr[right]])
                left, right = left + 1, right - 1
                # Step past repeats of the two values just recorded.
                while left < right and arr[left] == arr[left - 1]:
                    left += 1
                while left < right and arr[right] == arr[right + 1]:
                    right -= 1
    return triplets


# ----------------------------------------------------------------------
# Solution B — Hash-set inner two-sum.
# ----------------------------------------------------------------------

def three_sum_hashset(nums: list[int]) -> list[list[int]]:
    """3Sum via a fixed anchor plus a hash-set two-sum over the suffix.

    Works on a sorted copy of nums. For each distinct anchor `a`, values to
    its right are scanned left to right. A triplet [a, b, c] is emitted when
    the complement `b = -a - c` has already been seen in that scan. Because
    the input is sorted, equal values of `c` are adjacent, so remembering the
    last emitted `c` is enough to suppress duplicates.

    Time O(N^2) expected; extra space O(N) for the copy and the per-anchor set.
    """
    ordered = sorted(nums)
    result: list[list[int]] = []
    for idx in range(len(ordered) - 2):
        a = ordered[idx]
        if idx > 0 and a == ordered[idx - 1]:
            continue  # duplicate anchor value
        if a > 0:
            break  # all later values are positive as well
        goal = -a
        earlier: set[int] = set()
        prev_c: int | None = None
        for c in ordered[idx + 1:]:
            b = goal - c
            if b in earlier and c != prev_c:
                result.append([a, b, c])
                prev_c = c
            earlier.add(c)
    return result


# ----------------------------------------------------------------------
# Solution C — Brute O(N^3). Oracle.
# ----------------------------------------------------------------------

def three_sum_brute(nums: list[int]) -> list[list[int]]:
    """Reference oracle: try every index triple i < j < k (O(N^3)).

    Matching triplets are normalized to ascending tuples and collected in a
    set to remove duplicates. The returned list order is unspecified.
    """
    size = len(nums)
    found = {
        tuple(sorted((nums[a], nums[b], nums[c])))
        for a in range(size)
        for b in range(a + 1, size)
        for c in range(b + 1, size)
        if nums[a] + nums[b] + nums[c] == 0
    }
    return [list(triple) for triple in found]


# ----------------------------------------------------------------------
# Bonus — Generic k-Sum.
# ----------------------------------------------------------------------

def k_sum(nums: list[int], target: int, k: int) -> list[list[int]]:
    """Return all unique k-element combinations of nums summing to target.

    Each combination is listed in nondecreasing order, and no combination is
    reported twice even when nums contains repeated values. A sorted copy is
    searched, so the caller's list is not mutated.

    Time: O(N log N) for the sort, plus O(N^(k-1)) for k >= 2.

    Raises:
        ValueError: if k < 1.
    """
    if k < 1:
        # Without this guard the k >= 3 branch of _k_sum would recurse with
        # ever-smaller k and index past the end of nums.
        raise ValueError(f"k must be >= 1, got {k}")
    nums = sorted(nums)  # don't mutate caller's list
    return _k_sum(nums, target, k, 0)


def _k_sum(nums: list[int], target: int, k: int, start: int) -> list[list[int]]:
    """Return unique k-combinations drawn from nums[start:] that sum to target.

    Preconditions: nums is sorted ascending and k >= 1.
    """
    n = len(nums)
    if k == 1:
        # Single-element base case: report the value once even if repeated.
        return [[target]] if target in nums[start:] else []
    out: list[list[int]] = []
    if k == 2:
        lo, hi = start, n - 1
        while lo < hi:
            s = nums[lo] + nums[hi]
            if s == target:
                out.append([nums[lo], nums[hi]])
                lo += 1
                hi -= 1
                # Skip repeats of the pair just recorded.
                while lo < hi and nums[lo] == nums[lo - 1]:
                    lo += 1
                while lo < hi and nums[hi] == nums[hi + 1]:
                    hi -= 1
            elif s < target:
                lo += 1
            else:
                hi -= 1
        return out
    # k >= 3: fix the first (smallest) element, recurse on the remainder.
    for i in range(start, n - k + 1):
        if i > start and nums[i] == nums[i - 1]:
            continue  # this first value was already expanded
        # Lower bound: every element from index i on is >= nums[i], so any
        # combination starting here sums to at least k * nums[i]. Later i only
        # raise that bound, so once it exceeds target we can stop. This holds
        # for any sign of target.
        if nums[i] * k > target:
            break
        # Upper bound: nums[i] plus (k - 1) copies of the maximum. If even that
        # falls short, try a larger first element instead.
        if nums[i] + nums[-1] * (k - 1) < target:
            continue
        for tail in _k_sum(nums, target - nums[i], k - 1, i + 1):
            out.append([nums[i]] + tail)
    return out


# ----------------------------------------------------------------------
# Helpers.
# ----------------------------------------------------------------------

def canon(triplets: list[list[int]]) -> list[tuple[int, ...]]:
    """Normalize combinations for order-insensitive comparison.

    Each combination is sorted and turned into a tuple, and the resulting
    tuples are sorted. Works for combinations of any length (pairs from
    2-sum, triplets, quadruplets), which is why the element type is
    `tuple[int, ...]` rather than a fixed 3-tuple.
    """
    return sorted(tuple(sorted(t)) for t in triplets)


# ----------------------------------------------------------------------
# Stress test.
# ----------------------------------------------------------------------

def stress_test() -> None:
    """Cross-check all 3Sum implementations, and k_sum for 4Sum, on random inputs.

    A fixed seed keeps runs reproducible. The first disagreement raises
    AssertionError with the offending input; otherwise a summary line is printed.
    """
    rng = random.Random(42)
    n_iter = 200
    four_sum_iters = 50

    # 3Sum: every implementation must agree after canonicalization.
    for _ in range(n_iter):
        length = rng.randint(0, 15)
        sample = [rng.randint(-6, 6) for _ in range(length)]
        two_ptr = canon(three_sum(sample))
        hashed = canon(three_sum_hashset(sample))
        brute = canon(three_sum_brute(sample))
        generic = canon(k_sum(sample, 0, 3))
        assert two_ptr == hashed == brute == generic, (
            f"Disagreement on nums={sample}\n2p={two_ptr}\nhash={hashed}\nbrute={brute}\nksum={generic}"
        )

    # 4Sum: k_sum versus an exhaustive quadruple loop.
    for _ in range(four_sum_iters):
        length = rng.randint(0, 10)
        sample = [rng.randint(-5, 5) for _ in range(length)]
        target = rng.randint(-8, 8)
        expected: set[tuple[int, ...]] = set()
        for p in range(length):
            for q in range(p + 1, length):
                for r in range(q + 1, length):
                    for s in range(r + 1, length):
                        quad = (sample[p], sample[q], sample[r], sample[s])
                        if sum(quad) == target:
                            expected.add(tuple(sorted(quad)))
        got = {tuple(combo) for combo in k_sum(sample, target, 4)}
        assert got == expected, (
            f"4Sum mismatch nums={sample} target={target}\ngot={got}\nbrute={expected}"
        )

    print(
        f"stress_test: {n_iter} random arrays (3Sum) + {four_sum_iters} (4Sum) — "
        f"2-pointer, hash, brute, k_sum all agree."
    )


# ----------------------------------------------------------------------
# Edge cases.
# ----------------------------------------------------------------------

def edge_cases() -> None:
    """Check hand-picked inputs with known answers against every implementation.

    Raises AssertionError on the first mismatch; otherwise prints a PASS line.
    """
    impls = (three_sum, three_sum_hashset, three_sum_brute)

    # Fewer than three elements: nothing to return.
    for fn in impls:
        for tiny in ([], [0], [0, 0]):
            assert fn(tiny) == []

    # Runs of zeros collapse to a single (0, 0, 0).
    for fn in impls:
        for zeros in ([0, 0, 0], [0, 0, 0, 0], [0] * 10):
            assert canon(fn(zeros)) == [(0, 0, 0)]

    # Single-signed inputs can never sum to zero.
    for fn in impls:
        assert fn([1, 2, 3, 4]) == []
    for fn in impls:
        assert fn([-1, -2, -3, -4]) == []

    # The example from the LC 15 problem statement.
    for fn in impls:
        assert canon(fn([-1, 0, 1, 2, -1, -4])) == [(-1, -1, 2), (-1, 0, 1)]

    # Many duplicates, one answer.
    for fn in impls:
        assert canon(fn([-2, 0, 0, 2, 2])) == [(-2, 0, 2)]

    # k == 2 goes straight to the two-pointer base case.
    assert canon(k_sum([2, 7, 11, 15], 9, 2)) == [(2, 7)]

    print("edge_cases: PASS")


if __name__ == "__main__":
    # Deterministic checks first, then the randomized cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
