"""
p91 — Range Sum Query Mutable (LC 307)

INVARIANT (BIT): tree[i] stores the sum of arr[i - lowbit(i) + 1 .. i], using
  1-indexed positions, where lowbit(i) = i & -i. update walks i += lowbit(i);
  prefix walks i -= lowbit(i).

INVARIANT (iter segment tree): leaves at [n, 2n). Internal node i (1 <= i < n)
  stores the sum of its children 2i and 2i+1; tree[0] is unused.
"""
from __future__ import annotations

import random
from typing import List


class NumArrayBIT:
    """Fenwick tree (binary indexed tree) with point assignment and range sums.

    ``tree`` is 1-indexed: ``tree[i]`` holds the sum of the ``lowbit(i)``
    elements ending at 0-indexed position ``i - 1``, where
    ``lowbit(i) = i & -i``. ``tree[0]`` is unused.
    """

    def __init__(self, nums: List[int]) -> None:
        """Build the tree from *nums* in O(n).

        Each node adds its finished partial sum into its parent
        ``i + lowbit(i)``. Parents always have larger indices, so one
        left-to-right pass completes every node before it is read.
        """
        self.n = len(nums)
        # Current values, cached so update() can turn an assignment into
        # the delta the Fenwick tree needs.
        self.arr = list(nums)
        self.tree = [0] + self.arr  # tree[i] starts as arr[i - 1]
        for i in range(1, self.n + 1):
            parent = i + (i & -i)
            if parent <= self.n:
                self.tree[parent] += self.tree[i]

    def update(self, index: int, val: int) -> None:
        """Set element *index* (0-based) to *val* in O(log n).

        Raises:
            IndexError: if *index* is outside ``[0, n)``. Negative indices
                are rejected rather than wrapped: the cache would wrap but
                the tree walk would not, silently desynchronizing the two.
        """
        if not 0 <= index < self.n:
            raise IndexError(f"index {index} out of range for size {self.n}")
        delta = val - self.arr[index]
        self.arr[index] = val
        i = index + 1  # shift to the tree's 1-indexed positions
        while i <= self.n:
            self.tree[i] += delta
            i += i & -i

    def _prefix(self, i: int) -> int:
        """Return the sum of arr[0..i-1], i.e. the first *i* elements."""
        s = 0
        while i > 0:
            s += self.tree[i]
            i -= i & -i
        return s

    def sumRange(self, left: int, right: int) -> int:
        """Return the sum of arr[left..right] inclusive.

        Expects ``0 <= left <= right < n``, as in LC 307.
        """
        return self._prefix(right + 1) - self._prefix(left)


class NumArraySegTree:
    """Bottom-up (iterative) segment tree for range sums in 2n slots.

    Leaves live at ``tree[n:2n]``, with element ``k`` at ``tree[n + k]``.
    Each internal node ``i`` in ``[1, n)`` holds ``tree[2i] + tree[2i + 1]``.
    ``tree[0]`` is unused. Sizes need not be powers of two.
    """

    def __init__(self, nums: List[int]) -> None:
        """Build the tree in O(n): place leaves, then aggregate bottom-up."""
        self.n = len(nums)
        self.tree = [0] * self.n + list(nums)
        # Children (2i, 2i+1) always have larger indices than i, so a
        # descending sweep finalizes both before the parent reads them.
        for i in range(self.n - 1, 0, -1):
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def update(self, index: int, val: int) -> None:
        """Set element *index* (0-based) to *val* and refresh its ancestors.

        Raises:
            IndexError: if *index* is outside ``[0, n)``. Without this check,
                a negative index would land on internal node ``n + index``
                and overwrite an aggregate instead of a leaf.
        """
        if not 0 <= index < self.n:
            raise IndexError(f"index {index} out of range for size {self.n}")
        i = index + self.n
        self.tree[i] = val
        i //= 2
        while i > 0:
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]
            i //= 2

    def sumRange(self, left: int, right: int) -> int:
        """Return the sum of elements left..right inclusive.

        Expects ``0 <= left <= right < n``, as in LC 307.
        """
        l = left + self.n
        r = right + self.n + 1  # half-open on the right
        s = 0
        while l < r:
            # Odd l is a right child whose parent reaches left of `left`:
            # take it alone and step past it.
            if l & 1:
                s += self.tree[l]
                l += 1
            # Odd r means r-1 is a left child whose parent reaches past
            # `right`: take tree[r-1] alone.
            if r & 1:
                r -= 1
                s += self.tree[r]
            l //= 2
            r //= 2
        return s


class NumArrayBrute:
    """Reference oracle backed by a plain list: O(1) update, O(n) query."""

    def __init__(self, nums: List[int]) -> None:
        # Private copy so later updates never touch the caller's list.
        self.arr = [*nums]

    def update(self, index: int, val: int) -> None:
        """Overwrite the value stored at *index*."""
        values = self.arr
        values[index] = val

    def sumRange(self, left: int, right: int) -> int:
        """Sum the inclusive window [left, right] by slicing."""
        window = self.arr[left : right + 1]
        return sum(window)


def edge_cases() -> None:
    """Hand-checked scenarios run against both fast implementations."""
    implementations = (NumArrayBIT, NumArraySegTree)

    # LC 307 sample: [1, 3, 5], then set index 1 to 2.
    for impl in implementations:
        obj = impl([1, 3, 5])
        assert obj.sumRange(0, 2) == 9
        obj.update(1, 2)
        assert obj.sumRange(0, 2) == 8
        assert obj.sumRange(1, 1) == 2

    # Smallest non-empty input: one element, overwritten with a negative.
    for impl in implementations:
        obj = impl([7])
        assert obj.sumRange(0, 0) == 7
        obj.update(0, -3)
        assert obj.sumRange(0, 0) == -3


def stress_test() -> None:
    """Randomized differential test of BIT and segment tree vs. brute force.

    Uses a fixed seed so any failure is reproducible.
    """
    rng = random.Random(42)
    for _trial in range(200):
        size = rng.randint(1, 12)
        data = [rng.randint(-10, 10) for _ in range(size)]
        bit = NumArrayBIT(data[:])
        seg = NumArraySegTree(data[:])
        brute = NumArrayBrute(data[:])

        step_count = rng.randint(5, 30)
        for _step in range(step_count):
            if rng.choice(["u", "q"]) == "u":
                pos = rng.randint(0, size - 1)
                new_val = rng.randint(-10, 10)
                for structure in (bit, seg, brute):
                    structure.update(pos, new_val)
                continue

            lo = rng.randint(0, size - 1)
            hi = rng.randint(lo, size - 1)
            expected = brute.sumRange(lo, hi)
            assert bit.sumRange(lo, hi) == expected
            assert seg.sumRange(lo, hi) == expected


if __name__ == "__main__":
    # Deterministic cases first, then the randomized cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
