"""
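p31 — Subarray Sum Equals K (LeetCode 560).

Count the contiguous, non-empty subarrays of ``nums`` whose elements sum to ``k``.
Elements may be negative, so a two-pointer sliding window does not apply.

Two implementations:
    subarray_sum        — canonical prefix-sum + hashmap. O(N) time, O(N) space.
    subarray_sum_brute  — O(N^2) double loop; used as the test ORACLE.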

Run: python3 solution.py
"""
from __future__ import annotations
import random
from collections import defaultdict
from typing import List


# --------------------------------------------------------------------------------------
# Canonical: prefix sum + hashmap. O(N) time, O(N) space.
# --------------------------------------------------------------------------------------
def subarray_sum(nums: List[int], k: int) -> int:
    """Return how many contiguous, non-empty subarrays of ``nums`` sum to ``k``.

    Let P[i] be the sum of the first i elements (P[0] == 0). The slice
    nums[i:j] sums to k exactly when P[i] == P[j] - k. So, walking left to
    right, every earlier prefix equal to (current prefix - k) marks the start
    of one qualifying subarray ending here.

    Runs in O(N) time and O(N) extra space.
    """
    # Tally of prefix sums already passed. It is seeded with the empty prefix
    # (sum 0) so that subarrays beginning at index 0 are counted too.
    seen: dict[int, int] = {0: 1}
    prefix = 0
    total = 0
    for value in nums:
        prefix += value
        # Query before recording, so the current prefix is never paired with
        # itself (which would correspond to an empty subarray).
        total += seen.get(prefix - k, 0)
        seen[prefix] = seen.get(prefix, 0) + 1
    return total


# --------------------------------------------------------------------------------------
# Brute oracle: O(N^2) double loop with running inner sum.
# --------------------------------------------------------------------------------------
def subarray_sum_brute(nums: List[int], k: int) -> int:
    """Oracle: count subarrays summing to ``k`` by trying every start index.

    For each start, extend the window one element at a time while keeping a
    running total, so each (start, end) pair costs O(1). The whole search is
    O(N^2). It is kept deliberately simple so it can serve as ground truth.
    """
    hits = 0
    for start in range(len(nums)):
        window_total = 0
        for value in nums[start:]:
            window_total += value
            hits += window_total == k  # bool adds as 0/1
    return hits


# --------------------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------------------
def edge_cases() -> None:
    """Check hand-picked cases against both a hand count and the brute oracle.

    Each case is ``(nums, k, expected)``. The hand-computed ``expected`` is
    first validated against ``subarray_sum_brute``, so a wrong hand count is
    reported instead of silently ignored. ``subarray_sum`` is then checked
    against the oracle.
    """
    cases = [
        ([], 0, 0),
        ([1], 1, 1),
        ([1], 0, 0),
        ([0], 0, 1),
        ([1, 1, 1], 2, 2),                        # LC canonical
        ([1, 2, 3], 3, 2),                        # [3] and [1,2]
        ([0, 0, 0], 0, 6),                        # 3 + 2 + 1 = C(4,2) pairs of equal prefixes
        ([1, -1, 0], 0, 3),                       # [1,-1], [0], [1,-1,0]
        ([1, -1, 1, -1, 1, -1], 0, 9),            # interleaved
        ([1, 2, 1, 2, 1], 3, 4),                  # [1,2]x2, [2,1]x2
        ([-1, -1, 1], 0, 1),                      # [-1,1]
        ([100, 100, 100], 200, 2),                # [100,100]x2
        ([-1000] * 10, 0, 0),                     # never reaches 0
        ([1, -1] * 10, 0, 100),                   # prefixes: eleven 0s, ten 1s -> C(11,2)+C(10,2)
    ]
    for nums, k, expected in cases:
        truth = subarray_sum_brute(nums, k)
        # Validate the hand count so a wrong table entry is caught, not ignored.
        assert truth == expected, (
            f"hand-computed expected for ({nums}, {k}) is {expected}, brute says {truth}"
        )
        got = subarray_sum(nums, k)
        assert got == truth, f"subarray_sum({nums}, {k}) → {got}, brute says {truth}"
    print("edge_cases: PASS")


def stress_test(trials: int = 300, seed: int = 42) -> None:
    """Cross-check ``subarray_sum`` against the brute oracle on random inputs.

    Arrays have length 0..25 with values in [-5, 5], so negative values are
    exercised. Targets ``k`` are drawn from [-10, 10].

    Args:
        trials: number of random ``(nums, k)`` instances to check.
        seed: seed for a private ``random.Random``, so runs are reproducible
            and the global RNG state is left untouched.
    """
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(0, 25)
        nums = [rng.randint(-5, 5) for _ in range(n)]
        k = rng.randint(-10, 10)
        a = subarray_sum(nums, k)
        b = subarray_sum_brute(nums, k)
        assert a == b, f"DISAGREE nums={nums} k={k}: hashmap={a} brute={b}"
    # Derive the count from `trials` so the message cannot drift from the loop.
    print(f"stress_test: {trials} random arrays (with negatives) — hashmap and brute agree.")


if __name__ == "__main__":
    # Run the deterministic table first, then the randomized cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
