"""
p32 — Range Sum Query 2D Immutable (LeetCode 304).

Class NumMatrix:
    __init__(matrix)              — precompute (R+1)x(C+1) 2D prefix sum in O(R*C).
    sumRegion(r1, c1, r2, c2)     — O(1) via inclusion-exclusion.
    sumRegion_brute(r1, c1, r2, c2) — O(area) double loop. ORACLE.

Run: python3 solution.py
"""
from __future__ import annotations
import random
from typing import List


class NumMatrix:
    """Immutable 2D grid answering axis-aligned rectangle-sum queries.

    The constructor takes a private snapshot of ``matrix``. Later mutation of
    the caller's list therefore cannot desynchronise the O(1) prefix path from
    the brute-force oracle.

    Attributes:
        R, C: number of rows / columns (both 0 for an empty matrix).
        P:    (R+1) x (C+1) prefix-sum table (see invariant in ``__init__``).
    """

    def __init__(self, matrix: List[List[int]]):
        """Precompute the (R+1)x(C+1) 2D prefix-sum table in O(R*C).

        Args:
            matrix: rectangular grid of ints; ``[]`` or ``[[]]`` is treated as empty.

        Raises:
            ValueError: if the rows of ``matrix`` do not all have the same length.
        """
        # Defensive copy: the structure is "immutable", and the brute oracle
        # must read exactly the data the prefix table was built from.
        rows = [list(row) for row in matrix]
        self._raw = rows  # kept for the brute oracle only
        C = len(rows[0]) if rows else 0
        # Ragged input would otherwise be silently truncated or raise an
        # unhelpful IndexError deep inside the build loop.
        if any(len(row) != C for row in rows):
            raise ValueError("matrix rows must all have the same length")
        if not rows or C == 0:
            self.R = 0
            self.C = 0
            self.P = [[0]]
            return
        R = len(rows)
        self.R, self.C = R, C
        # INVARIANT: P[r][c] = sum of matrix[i][j] for 0 <= i < r and 0 <= j < c.
        # Border row 0 and column 0 stay zero — eliminates edge branches in queries.
        P = [[0] * (C + 1) for _ in range(R + 1)]
        for r in range(1, R + 1):
            src, above, cur = rows[r - 1], P[r - 1], P[r]
            row_running = 0
            for c in range(1, C + 1):
                row_running += src[c - 1]
                cur[c] = above[c] + row_running
        self.P = P

    def _check_rect(self, r1: int, c1: int, r2: int, c2: int) -> None:
        """Raise IndexError unless (r1,c1)-(r2,c2) is a non-empty in-bounds rectangle.

        Without this, negative indices would silently wrap around (Python
        negative indexing) and inverted ranges would yield meaningless sums.
        """
        if not (0 <= r1 <= r2 < self.R and 0 <= c1 <= c2 < self.C):
            raise IndexError(
                f"invalid region ({r1},{c1})-({r2},{c2}) for {self.R}x{self.C} matrix"
            )

    def sumRegion(self, r1: int, c1: int, r2: int, c2: int) -> int:
        """Return the sum of matrix[r1..r2][c1..c2] (inclusive) in O(1).

        Raises:
            IndexError: if the rectangle is out of bounds or inverted.
        """
        self._check_rect(r1, c1, r2, c2)
        # Inclusion-exclusion on the (R+1)x(C+1) prefix:
        #   full - top - left + topleft
        P = self.P
        return P[r2 + 1][c2 + 1] - P[r1][c2 + 1] - P[r2 + 1][c1] + P[r1][c1]

    def sumRegion_brute(self, r1: int, c1: int, r2: int, c2: int) -> int:
        """Reference O(area) implementation of ``sumRegion`` (test oracle).

        Raises:
            IndexError: if the rectangle is out of bounds or inverted.
        """
        self._check_rect(r1, c1, r2, c2)
        total = 0
        for r in range(r1, r2 + 1):
            for c in range(c1, c2 + 1):
                total += self._raw[r][c]
        return total


# --------------------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------------------
def edge_cases() -> None:
    """Hand-picked regression cases for NumMatrix, including degenerate shapes."""
    # LC canonical
    m = [
        [3, 0, 1, 4, 2],
        [5, 6, 3, 2, 1],
        [1, 2, 0, 1, 5],
        [4, 1, 0, 1, 7],
        [1, 0, 3, 0, 5],
    ]
    nm = NumMatrix(m)
    assert nm.sumRegion(2, 1, 4, 3) == 8
    assert nm.sumRegion(1, 1, 2, 2) == 11
    assert nm.sumRegion(1, 2, 2, 4) == 12

    # Empty inputs: exercise the constructor's explicit empty-matrix branch.
    for empty in ([], [[]]):
        nm = NumMatrix(empty)
        assert nm.R == 0 and nm.C == 0

    # 1x1
    nm = NumMatrix([[7]])
    assert nm.sumRegion(0, 0, 0, 0) == 7

    # Single row
    nm = NumMatrix([[1, 2, 3, 4]])
    assert nm.sumRegion(0, 0, 0, 3) == 10
    assert nm.sumRegion(0, 1, 0, 2) == 5

    # Single column
    nm = NumMatrix([[1], [2], [3], [4]])
    assert nm.sumRegion(0, 0, 3, 0) == 10
    assert nm.sumRegion(1, 0, 2, 0) == 5

    # Negative values
    nm = NumMatrix([[-1, -2], [-3, -4]])
    assert nm.sumRegion(0, 0, 1, 1) == -10
    assert nm.sumRegion(0, 0, 0, 0) == -1

    # All-zero
    nm = NumMatrix([[0, 0], [0, 0]])
    assert nm.sumRegion(0, 0, 1, 1) == 0

    # Mixed
    nm = NumMatrix([[1, -1], [-1, 1]])
    assert nm.sumRegion(0, 0, 1, 1) == 0
    assert nm.sumRegion(0, 0, 0, 0) == 1
    print("edge_cases: PASS")


def stress_test() -> None:
    """Randomised cross-check: O(1) prefix query vs. O(area) brute-force oracle.

    Uses a fixed seed so any failure is reproducible.
    """
    gen = random.Random(42)
    for _trial in range(200):
        rows = gen.randint(1, 8)
        cols = gen.randint(1, 8)
        grid = [[gen.randint(-10, 10) for _ in range(cols)] for _ in range(rows)]
        structure = NumMatrix(grid)
        for _ in range(30):
            # Draw order (top, bottom, left, right) must stay fixed for reproducibility.
            top = gen.randint(0, rows - 1)
            bottom = gen.randint(top, rows - 1)
            left = gen.randint(0, cols - 1)
            right = gen.randint(left, cols - 1)
            fast = structure.sumRegion(top, left, bottom, right)
            slow = structure.sumRegion_brute(top, left, bottom, right)
            assert fast == slow, (
                f"DISAGREE matrix={grid} ({top},{left})-({bottom},{right}): "
                f"prefix={fast} brute={slow}"
            )
    print("stress_test: 200 matrices × 30 queries each — prefix and brute agree.")


if __name__ == "__main__":
    # Run the deterministic checks first, then the randomised cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
