"""
p35 — Median of Two Sorted Arrays (LeetCode 4, HARD).

Two implementations:
    find_median_sorted_arrays  — partition binary search. O(log(min(m, n))).
    find_median_brute          — two-pointer merge. O(m + n). ORACLE.

Run: python3 solution.py
"""
from __future__ import annotations
import math
import random
from typing import List


# --------------------------------------------------------------------------------------
# Canonical: binary search on partition of the smaller array. O(log(min(m, n))).
# --------------------------------------------------------------------------------------
def find_median_sorted_arrays(nums1: List[int], nums2: List[int]) -> float:
    """Return the median of the union of two sorted arrays in O(log(min(m, n))).

    Binary-searches how many elements of the shorter array belong to the
    "left half" of the combined sorted sequence. A partition is valid when
    every element on the left is <= every element on the right; the median is
    then read off the four elements bordering the cut.

    Args:
        nums1: Array sorted in non-decreasing order (may be empty).
        nums2: Array sorted in non-decreasing order (may be empty).

    Returns:
        The median as a float: the middle element when m + n is odd, the mean
        of the two middle elements when m + n is even.

    Raises:
        ValueError: If both arrays are empty (no median exists).
        RuntimeError: If the search exhausts its range without finding a valid
            partition. This can only happen for unsorted input; unsorted input
            may also produce a meaningless value instead of an error.
    """
    A, B = nums1, nums2
    # Always binary-search the SMALLER array so the search range is
    # [0, len(A)] and j = half - i stays in [0, len(B)].
    if len(A) > len(B):
        A, B = B, A
    m, n = len(A), len(B)
    T = m + n
    if T == 0:
        # Without this guard the search "finds" the empty partition (all four
        # boundaries are +/-inf sentinels) and returns (-inf + inf) / 2 == nan.
        raise ValueError("invalid input: arrays must contain at least one element combined")
    half = (T + 1) // 2  # size of the combined left side (odd T: median is the max-left)

    lo, hi = 0, m
    INF = math.inf
    # INVARIANT: the correct partition has A's left-side size `i` in [lo, hi].
    # Search [0, m] until we find i (with j = half - i) such that:
    #   maxLeftA <= minRightB  AND  maxLeftB <= minRightA
    while lo <= hi:
        i = (lo + hi) // 2
        j = half - i

        # +/-inf sentinels stand in for "no element on this side of the cut".
        maxLeftA = A[i - 1] if i > 0 else -INF
        minRightA = A[i] if i < m else INF
        maxLeftB = B[j - 1] if j > 0 else -INF
        minRightB = B[j] if j < n else INF

        if maxLeftA <= minRightB and maxLeftB <= minRightA:
            # Found valid partition.
            if T % 2 == 1:
                return float(max(maxLeftA, maxLeftB))
            return (max(maxLeftA, maxLeftB) + min(minRightA, minRightB)) / 2.0
        elif maxLeftA > minRightB:
            # Too many A elements on the left -> shrink i.
            hi = i - 1
        else:
            # Too few A elements on the left -> grow i.
            lo = i + 1
    # For sorted, non-empty input a valid partition always exists in [0, m],
    # so falling out of the loop means the sortedness precondition was violated.
    raise RuntimeError("invalid input: arrays must be sorted in non-decreasing order")


# --------------------------------------------------------------------------------------
# Brute oracle: fully merge with two pointers, then index the middle. O(m + n).
# --------------------------------------------------------------------------------------
def find_median_brute(nums1: List[int], nums2: List[int]) -> float:
    """Reference median: merge both arrays into one list, then index its middle.

    Runs in O(m + n) time and space and serves as the test oracle for
    ``find_median_sorted_arrays``. On ties the element from ``nums1`` is taken
    first. If both arrays are empty, indexing the (empty) merged list raises
    IndexError.
    """
    p, q = 0, 0
    len1, len2 = len(nums1), len(nums2)
    out: List[int] = []
    while p < len1 and q < len2:
        left, right = nums1[p], nums2[q]
        if left <= right:
            out.append(left)
            p += 1
        else:
            out.append(right)
            q += 1
    # At most one of these tails is non-empty.
    out += nums1[p:]
    out += nums2[q:]

    mid, odd = divmod(len(out), 2)
    if odd:
        return float(out[mid])
    return (out[mid - 1] + out[mid]) / 2.0


# --------------------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------------------
def edge_cases() -> None:
    """Check find_median_sorted_arrays against hand-computed medians.

    Each entry is (nums1, nums2, expected). The cases are asserted in order,
    and a PASS line is printed once all of them hold.
    """
    cases = [
        # LC canonical
        ([1, 3], [2], 2.0),
        ([1, 2], [3, 4], 2.5),
        # One empty
        ([], [1], 1.0),
        ([2], [], 2.0),
        ([], [1, 2], 1.5),
        ([1, 2, 3, 4], [], 2.5),
        # Single elements
        ([1], [2], 1.5),
        ([5], [5], 5.0),
        # No overlap (second entry is the same pair swapped)
        ([1, 2, 3], [4, 5, 6], 3.5),
        ([4, 5, 6], [1, 2, 3], 3.5),
        # All equal
        ([1, 1, 1], [1, 1, 1], 1.0),
        ([2, 2], [2, 2, 2], 2.0),
        # Interleaved
        ([1, 3, 5, 7], [2, 4, 6, 8], 4.5),
        # Negatives
        ([-5, -3, -1], [-4, -2, 0], -2.5),
        # Skewed sizes
        ([1], [2, 3, 4, 5, 6, 7], 4.0),
        # Large vs single
        ([1, 2, 3, 4, 5, 6, 7, 8, 9], [10], 5.5),
    ]
    for nums1, nums2, expected in cases:
        assert find_median_sorted_arrays(nums1, nums2) == expected

    print("edge_cases: PASS")


def stress_test() -> None:
    """Cross-check the partition search against the merge oracle on random input.

    Uses a fixed seed (42), so the 200 generated pairs are reproducible. Each
    pair has lengths drawn from [0, 15] and values from [-30, 30]. If both
    lengths come out 0, the second array is bumped to length 1 so a median
    exists.
    """
    rng = random.Random(42)

    def draw(count: int) -> List[int]:
        # Sorted list of `count` random values in [-30, 30].
        return sorted([rng.randint(-30, 30) for _ in range(count)])

    for trial in range(200):
        m, n = rng.randint(0, 15), rng.randint(0, 15)
        if not (m or n):
            n = 1  # keep the union non-empty
        a, b = draw(m), draw(n)
        got, truth = find_median_sorted_arrays(a, b), find_median_brute(a, b)
        assert math.isclose(got, truth), f"DISAGREE trial={trial} a={a} b={b}: partition={got} merge={truth}"
    print("stress_test: 200 random sorted-array pairs — partition and merge agree.")


if __name__ == "__main__":
    # Run the fixed edge cases first, then the randomized cross-check.
    for run_suite in (edge_cases, stress_test):
        run_suite()
    print("ALL TESTS PASSED")
