"""
p96 — Basic Calculator II (LC 227)

INVARIANT (deferred-op stack): at any point during the scan, the stack holds
  additively-combinable terms, so the final answer is their sum. prev_op
  records the operator that precedes cur_num and is applied when cur_num is
  finalized: "+" pushes cur_num, "-" pushes -cur_num, and the higher-precedence
  "*" and "/" immediately combine cur_num with the stack top in place.

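INVARIANT (truncation): integer division truncates toward zero, whereas
  Python's // rounds toward -inf. int(a / b) truncates correctly but goes
  through float division, so it can lose precision once operands reach 2**53;
  for arbitrary-size ints, floor-divide the magnitudes and reapply the sign.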
"""
from __future__ import annotations

import ast
import operator
import random
from typing import List


def calculate(s: str) -> int:
    """Evaluate an expression of non-negative integers joined by ``+ - * /``.

    Single-pass deferred-operator stack. O(n) time, O(n) space.

    Each number is finalized when the scan reaches the next operator (or the
    last character of ``s``). It is combined according to the operator that
    *preceded* it: ``+``/``-`` push a signed term, while ``*``/``/`` fold into
    the stack top. The result is the sum of the stack. Division truncates
    toward zero using exact integer arithmetic, so arbitrarily large
    intermediate values stay exact. Empty or blank input evaluates to 0.

    Raises:
        ValueError: if ``s`` contains a character that is not a digit,
            whitespace, or one of ``+ - * /``.
        ZeroDivisionError: if a divisor is zero.
    """
    stack: List[int] = []
    cur_num = 0
    prev_op = "+"
    last = len(s) - 1
    for i, c in enumerate(s):
        is_digit = c.isdigit()
        if is_digit:
            cur_num = cur_num * 10 + int(c)
        elif not c.isspace() and c not in "+-*/":
            # Without this check an unknown char would become prev_op and the
            # following number would be silently dropped.
            raise ValueError(f"unsupported character {c!r} at index {i}")
        # Apply prev_op when we hit an operator OR the end of the string.
        if (not is_digit and not c.isspace()) or i == last:
            if prev_op == "+":
                stack.append(cur_num)
            elif prev_op == "-":
                stack.append(-cur_num)
            elif prev_op == "*":
                stack.append(stack.pop() * cur_num)
            else:  # "/" -- prev_op is always a validated operator here.
                top = stack.pop()
                # cur_num >= 0, so the quotient takes top's sign; dividing the
                # magnitude with // truncates toward zero without float loss.
                quotient = abs(top) // cur_num
                stack.append(quotient if top >= 0 else -quotient)
            prev_op = c
            cur_num = 0
    return sum(stack)


_PRECEDENCE = {"+": 1, "-": 1, "*": 2, "/": 2}
_APPLY = {
    "+": operator.add,
    "-": operator.sub,
    "*": operator.mul,
    "/": lambda a, b: int(a / b),
}


def _tokenize(s: str) -> List[str]:
    tokens: List[str] = []
    i = 0
    while i < len(s):
        if s[i].isspace():
            i += 1
        elif s[i].isdigit():
            j = i
            while j < len(s) and s[j].isdigit():
                j += 1
            tokens.append(s[i:j])
            i = j
        else:
            tokens.append(s[i])
            i += 1
    return tokens


def calculate_shunting_yard(s: str) -> int:
    """Tokenize -> infix to postfix (shunting-yard) -> evaluate. O(n).

    Operators of equal precedence are popped before the new one is pushed
    (``>=``), which makes them left-associative: ``100/10/2 == 5``. Empty or
    blank input evaluates to 0, matching ``calculate``.

    Raises:
        ValueError: on an unsupported token, or a malformed expression
            (an operator missing an operand, or leftover operands as in
            ``"3 4"``).
        ZeroDivisionError: if a divisor evaluates to zero.
    """
    tokens = _tokenize(s)
    if not tokens:
        return 0

    output: List[str] = []
    ops: List[str] = []
    for tok in tokens:
        if tok.isdigit():
            output.append(tok)
            continue
        if tok not in _PRECEDENCE:
            raise ValueError(f"unsupported token {tok!r}")
        # Flush operators that bind at least as tightly (left associativity).
        while ops and _PRECEDENCE[ops[-1]] >= _PRECEDENCE[tok]:
            output.append(ops.pop())
        ops.append(tok)
    # Remaining operators are emitted in stack (LIFO) order.
    output.extend(reversed(ops))

    eval_stack: List[int] = []
    for tok in output:
        if tok.isdigit():
            eval_stack.append(int(tok))
            continue
        if len(eval_stack) < 2:
            raise ValueError(f"operator {tok!r} is missing an operand")
        b = eval_stack.pop()
        a = eval_stack.pop()
        eval_stack.append(_APPLY[tok](a, b))
    # More than one value left means operands were not joined by operators;
    # returning eval_stack[0] would silently give a wrong answer.
    if len(eval_stack) != 1:
        raise ValueError(f"malformed expression: {s!r}")
    return eval_stack[0]


def calculate_brute(s: str) -> int:
    """Oracle: use Python's AST with custom truncating division.

    Only integer literals and binary ``+ - * /`` are accepted. ``/``
    truncates toward zero using exact integer arithmetic (not float
    division), so large operands stay exact. The input is parsed with
    ``ast.parse`` only, never evaluated.

    Raises:
        SyntaxError: if ``s`` (after stripping) is not a Python expression.
        ValueError: for any node other than an int literal or a supported
            binary operator.
        ZeroDivisionError: if a divisor evaluates to zero.
    """

    tree = ast.parse(s.strip(), mode="eval")

    def ev(node: ast.AST) -> int:
        if isinstance(node, ast.Expression):
            return ev(node.body)
        # `type(...) is int` rejects floats, strings and bools (bool subclasses int).
        if isinstance(node, ast.Constant) and type(node.value) is int:
            return node.value
        if isinstance(node, ast.BinOp):
            a, b = ev(node.left), ev(node.right)
            if isinstance(node.op, ast.Add):
                return a + b
            if isinstance(node.op, ast.Sub):
                return a - b
            if isinstance(node.op, ast.Mult):
                return a * b
            if isinstance(node.op, ast.Div):
                # Truncate toward zero: divide magnitudes, then reapply sign.
                quotient = abs(a) // abs(b)
                return quotient if (a < 0) == (b < 0) else -quotient
        raise ValueError(f"unsupported: {ast.dump(node)}")

    return ev(tree)


def edge_cases() -> None:
    """Check hand-picked expressions against every evaluator in this module.

    Each case runs through ``calculate``, ``calculate_shunting_yard`` and the
    AST oracle ``calculate_brute``, so all three must agree with the expected
    value, not just the first.
    """
    cases = [
        ("3+2*2", 7),  # * binds tighter than +
        (" 3/2 ", 1),  # surrounding whitespace
        (" 3+5 / 2 ", 5),
        ("14-3/2", 13),  # calculate divides the signed term -3; floor would give 12
        ("0", 0),
        ("42", 42),
        ("1*2*3*4", 24),
        ("100/10/2", 5),  # left associativity of /
    ]
    evaluators = (calculate, calculate_shunting_yard, calculate_brute)
    for expr, expected in cases:
        for evaluate in evaluators:
            got = evaluate(expr)
            assert got == expected, (evaluate.__name__, expr, got, expected)


def stress_test() -> None:
    """Cross-check all three evaluators on 500 seeded random expressions.

    Each expression has 1-6 single-digit operands in 1..9 (so no division by
    zero). The same random spacing of 0-2 spaces surrounds an operator on both
    sides, and the whole string is padded with 0-2 spaces at each end. The
    fixed seed keeps any failure reproducible.
    """
    gen = random.Random(42)
    operators = ["+", "-", "*", "/"]
    spacings = ["", " ", "  "]

    def random_expression() -> str:
        term_count = gen.randint(1, 6)
        pieces: List[str] = [str(gen.randint(1, 9))]  # avoid div-by-zero
        for _ in range(term_count - 1):
            symbol = gen.choice(operators)
            operand = gen.randint(1, 9)
            gap = gen.choice(spacings)
            pieces.append(gap + symbol + gap + str(operand))
        # Left padding is drawn before right padding.
        left_pad = " " * gen.randint(0, 2)
        right_pad = " " * gen.randint(0, 2)
        return left_pad + "".join(pieces) + right_pad

    for _ in range(500):
        expr = random_expression()
        results = (
            calculate(expr),
            calculate_shunting_yard(expr),
            calculate_brute(expr),
        )
        assert results[0] == results[1] == results[2], (expr, *results)


if __name__ == "__main__":
    # Deterministic checks first, then the randomized cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
