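"""
p92 — Count of Smaller Numbers After Self (LC 315)

INVARIANT (BIT right-to-left): just before index i is processed, the BIT holds the
  multiset of nums[i+1..n-1] (as counts per compressed rank).
  ans[i] = prefix_count(rank(nums[i]) - 1) = # of already-seen values strictly smaller.

INVARIANT (merge sort): during a merge, when an element is taken from the left half,
  the # of right-half elements already taken is that element's contribution at this
  level: each of them is strictly smaller (ties take the left element first) and lies
  to its right in the original array.
"""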
from __future__ import annotations

import random
from typing import List


class _BIT:
    """Fenwick (binary indexed) tree over 1-based positions 1..n.

    Supports point-add and prefix-sum queries, each in O(log n).
    """

    def __init__(self, n: int) -> None:
        self.n = n
        # Index 0 is never used; positions are 1-based so lowbit arithmetic works.
        self.tree = [0] * (n + 1)

    def update(self, i: int, v: int = 1) -> None:
        """Add ``v`` at position ``i`` (1-based; callers must pass i >= 1)."""
        tree, size = self.tree, self.n
        pos = i
        while pos <= size:
            tree[pos] += v
            pos += pos & (-pos)  # jump to the next node whose range covers pos

    def prefix(self, i: int) -> int:
        """Return the sum of positions 1..i (0 when i <= 0)."""
        tree = self.tree
        total = 0
        pos = i
        while pos > 0:
            total += tree[pos]
            pos &= pos - 1  # drop the lowest set bit
        return total


def count_smaller_bit(nums: List[int]) -> List[int]:
    """Return, for each index, how many later elements are strictly smaller.

    Values are coordinate-compressed to dense ranks 1..k. The array is then
    scanned right-to-left while a Fenwick tree counts the values already seen,
    all of which lie to the right of the current index. O(n log n) time.
    """
    if not nums:
        return []
    # Dense 1-based rank per distinct value (Fenwick positions start at 1).
    rank = {value: pos for pos, value in enumerate(sorted(set(nums)), start=1)}
    seen = _BIT(len(rank))
    result = [0] * len(nums)
    for idx in reversed(range(len(nums))):
        r = rank[nums[idx]]
        result[idx] = seen.prefix(r - 1)  # ranks 1..r-1 are strictly smaller
        seen.update(r)
    return result


def count_smaller_mergesort(nums: List[int]) -> List[int]:
    """Return, for each index, how many later elements are strictly smaller.

    Runs a stable merge sort over ``(original_index, value)`` pairs. When an
    element is emitted from the left half, every right-half element already
    emitted is strictly smaller (ties emit the left element first) and
    originally lies to its right. That running count is added to the left
    element's answer. O(n log n) time.
    """
    ans = [0] * len(nums)

    def merge_sort(a: List[tuple[int, int]]) -> List[tuple[int, int]]:
        # ``a`` holds (original_index, value) pairs; returns them sorted by value.
        if len(a) <= 1:
            return a
        mid = len(a) // 2
        left = merge_sort(a[:mid])
        right = merge_sort(a[mid:])
        merged: List[tuple[int, int]] = []
        i = j = 0
        # ``j`` is also the number of right-half elements emitted so far.
        while i < len(left) and j < len(right):
            if left[i][1] <= right[j][1]:
                # ``<=`` keeps equal values out of the count (strictly smaller only).
                ans[left[i][0]] += j
                merged.append(left[i])
                i += 1
            else:
                merged.append(right[j])
                j += 1
        # Any left leftovers exceed all j (== len(right)) emitted right elements.
        for idx, _ in left[i:]:
            ans[idx] += j
        merged.extend(left[i:])
        merged.extend(right[j:])
        return merged

    # enumerate yields (original_index, value) pairs.
    merge_sort(list(enumerate(nums)))
    return ans


def count_smaller_brute(nums: List[int]) -> List[int]:
    """Reference oracle: count strictly smaller later elements pairwise. O(n^2)."""
    counts: List[int] = []
    for i, current in enumerate(nums):
        later = nums[i + 1:]
        counts.append(sum(1 for other in later if other < current))
    return counts


def edge_cases() -> None:
    """Check hand-picked inputs against every implementation, not just the BIT one."""
    cases = [
        ([5, 2, 6, 1], [2, 1, 1, 0]),    # LeetCode example
        ([], []),                         # empty
        ([7], [0]),                       # single element
        ([1, 2, 3, 4], [0, 0, 0, 0]),     # sorted ascending
        ([4, 3, 2, 1], [3, 2, 1, 0]),     # sorted descending
        ([2, 2, 2], [0, 0, 0]),           # duplicates: equal is not "smaller"
        ([-1, -2, 0, 3], [1, 0, 0, 0]),   # negative values
    ]
    solvers = (count_smaller_bit, count_smaller_mergesort, count_smaller_brute)
    for nums, expected in cases:
        for solve in solvers:
            # Fresh copy per solver so one implementation cannot affect another's input.
            got = solve(list(nums))
            assert got == expected, (solve.__name__, nums, got, expected)


def stress_test() -> None:
    """Randomized differential test: BIT and merge-sort answers must match the oracle."""
    rng = random.Random(42)  # fixed seed keeps failures reproducible
    trials = 200
    for _trial in range(trials):
        length = rng.randint(0, 12)
        sample = [rng.randint(-5, 5) for _ in range(length)]
        # Each solver gets its own copy so in-place mutation cannot mask a bug.
        from_bit, from_ms, oracle = (
            count_smaller_bit(list(sample)),
            count_smaller_mergesort(list(sample)),
            count_smaller_brute(list(sample)),
        )
        assert from_bit == oracle == from_ms, (sample, from_bit, from_ms, oracle)


if __name__ == "__main__":
    # Deterministic edge cases first, then the randomized cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
