"""
p78 — Cherry Pickup (LeetCode 741, HARD).

Reframe: a round trip A -> B -> A, with A = (0, 0) and B = (n-1, n-1), moving
right/down on the way out and left/up on the way back, is equivalent to two
simultaneous A -> B walks that both move right/down (reverse the return leg).

State: (r1, c1, r2) with c2 = r1 + c1 - r2 (step-count invariant).
Recurrence (max over 4 predecessor combos):
    dp[r1][c1][r2] = max over prev (r1-1 or c1-1 for agent 1; r2-1 or c2-1 for agent 2)
                     of dp[prev] + gain(r1,c1,r2,c2)
gain = grid[r1][c1] if (r1,c1)==(r2,c2) else grid[r1][c1] + grid[r2][c2]
If grid[r1][c1] == -1 or grid[r2][c2] == -1 -> state invalid (-inf).
Answer: max(0, dp[n-1][n-1][n-1]); an unreachable goal (-inf) yields 0.
"""
from __future__ import annotations
import random
from functools import lru_cache
from itertools import product
from typing import List

# Sentinel score for unreachable / invalid DP states; absorbs any finite addition.
NEG_INF: float = -float("inf")


def cherry_pickup_dp(grid: List[List[int]]) -> int:
    """Top-down memoized two-walker DP.

    Both walkers start at (0, 0) and take one right/down step per tick, so
    after t ticks r1 + c1 == r2 + c2 == t. The state is therefore
    (r1, c1, r2), with c2 derived. ``best_from`` returns the most cherries
    collectible from that state onward (including the current cells), or
    NEG_INF if the goal cannot be reached.

    Returns 0 when no thorn-free round trip exists (including an empty grid).
    """
    size = len(grid)
    steps = ((1, 0), (0, 1))  # (dr, dc): down, then right

    @lru_cache(maxsize=None)
    def best_from(r1: int, c1: int, r2: int) -> float:
        c2 = r1 + c1 - r2
        # Either walker has left the board.
        if max(r1, c1, r2, c2) >= size or min(r2, c2) < 0:
            return NEG_INF
        # Either walker is standing on a thorn.
        if -1 in (grid[r1][c1], grid[r2][c2]):
            return NEG_INF
        # Walker 1 at the goal: the equal step count plus the bounds check
        # above force walker 2 onto the same cell, so count it once.
        if (r1, c1) == (size - 1, size - 1):
            return grid[r1][c1]
        here = grid[r1][c1]
        if (r1, c1) != (r2, c2):
            here += grid[r2][c2]
        future = max(
            best_from(r1 + a[0], c1 + a[1], r2 + b[0])
            for a in steps
            for b in steps
        )
        return NEG_INF if future == NEG_INF else here + future

    outcome = best_from(0, 0, 0)
    if outcome == NEG_INF:
        return 0
    return max(0, int(outcome))


def cherry_pickup_iter(grid: List[List[int]]) -> int:
    """Bottom-up two-walker DP, processed in order of step count.

    After t steps both walkers satisfy r + c == t, so for a fixed t the state
    is just (r1, r2). ``dp[r1][r2]`` holds the most cherries collected so far
    with walker 1 at (r1, t - r1) and walker 2 at (r2, t - r2), or NEG_INF if
    that pair of cells is unreachable. Only the previous step's table is kept,
    giving O(n^2) memory and O(n^3) time.

    Returns 0 for an empty grid, a thorn on either corner, or when no
    thorn-free round trip exists.
    """
    n = len(grid)
    # Reject n == 0 before grid[0] is indexed (matches cherry_pickup_dp, which returns 0).
    if n == 0 or grid[0][0] == -1 or grid[n - 1][n - 1] == -1:
        return 0
    # dp[r1][r2] for the current step t; both walkers start on (0, 0), counted once.
    dp = [[NEG_INF] * n for _ in range(n)]
    dp[0][0] = grid[0][0]
    for t in range(1, 2 * n - 1):
        # Rows r whose column t - r lies in [0, n - 1], i.e. anti-diagonal t.
        rows = range(max(0, t - (n - 1)), min(n - 1, t) + 1)
        new_dp = [[NEG_INF] * n for _ in range(n)]
        for r1 in rows:
            c1 = t - r1
            if grid[r1][c1] == -1:
                continue
            for r2 in rows:
                c2 = t - r2
                if grid[r2][c2] == -1:
                    continue
                # Same step count, so equal rows means the same cell: count it once.
                gain = grid[r1][c1] if r1 == r2 else grid[r1][c1] + grid[r2][c2]
                best = NEG_INF
                # dr = 1: that walker arrived by moving down; dr = 0: by moving right.
                for dr1 in (0, 1):
                    for dr2 in (0, 1):
                        pr1, pr2 = r1 - dr1, r2 - dr2
                        pc1, pc2 = t - 1 - pr1, t - 1 - pr2
                        # Previous cells must be on the board; a negative index
                        # would silently wrap around in Python.
                        if pr1 < 0 or pr2 < 0 or not (0 <= pc1 < n and 0 <= pc2 < n):
                            continue
                        best = max(best, dp[pr1][pr2])
                # new_dp already holds NEG_INF for unreachable states.
                if best != NEG_INF:
                    new_dp[r1][r2] = best + gain
        dp = new_dp
    res = dp[n - 1][n - 1]
    return max(0, int(res)) if res != NEG_INF else 0


def cherry_pickup_brute(grid: List[List[int]]) -> int:
    """Exhaustive reference solver used to cross-check the DP versions.

    Enumerates every thorn-free right/down path from (0, 0) to (n-1, n-1),
    picks up its cherries (zeroing those cells on a private copy of the grid),
    then enumerates every thorn-free left/up path back to (0, 0) on that
    depleted copy. Returns the best total over all path pairs, or 0 if the
    grid is empty, a corner is a thorn, or no path exists. Runtime is
    exponential in n, so use it only on tiny grids. ``grid`` is not modified.
    """
    n = len(grid)
    # Reject n == 0 before grid[0] is indexed.
    if n == 0 or grid[0][0] == -1 or grid[n - 1][n - 1] == -1:
        return 0

    def monotone_paths(board, r, c, step, goal, path):
        # Yield every thorn-free path on `board` from (r, c) to `goal`.
        # step = +1 walks down/right, step = -1 walks up/left. `path` is a
        # shared scratch stack, so each yielded path is a copy.
        if board[r][c] == -1:
            return
        path.append((r, c))
        if (r, c) == goal:
            yield list(path)
        else:
            for nr, nc in ((r + step, c), (r, c + step)):
                if 0 <= nr < n and 0 <= nc < n:
                    yield from monotone_paths(board, nr, nc, step, goal, path)
        path.pop()

    best = 0
    for fwd in monotone_paths(grid, 0, 0, 1, (n - 1, n - 1), []):
        depleted = [row[:] for row in grid]
        gained = 0
        for r, c in fwd:
            gained += depleted[r][c]
            depleted[r][c] = 0
        for back in monotone_paths(depleted, n - 1, n - 1, -1, (0, 0), []):
            # A monotone path never revisits a cell, and cells shared with the
            # forward leg are already zeroed, so a plain sum counts each cherry once.
            best = max(best, gained + sum(depleted[r][c] for r, c in back))
    return best


def edge_cases() -> None:
    """Check every solver against small hand-verified grids.

    Each case runs through all three implementations, so a regression in any
    one of them (not just the memoized DP) is caught.
    """
    cases = [
        ([[0, 1, -1], [1, 0, -1], [1, 1, 1]], 5),   # LeetCode 741 example 1
        ([[1, 1, -1], [1, -1, 1], [-1, 1, 1]], 0),  # LeetCode 741 example 2: no path
        ([[1]], 1),
        ([[0]], 0),
        ([[-1]], 0),  # blocked
        ([[1, 1, 1], [1, -1, 1], [1, 1, 1]], 8),  # two paths around
    ]
    solvers = (cherry_pickup_dp, cherry_pickup_iter, cherry_pickup_brute)
    for grid, expected in cases:
        for solve in solvers:
            got = solve(grid)
            assert got == expected, f"{solve.__name__}({grid}) = {got}, expected {expected}"
    print("edge_cases: PASS")


def stress_test(trials: int = 80, seed: int = 42, max_n: int = 4) -> None:
    """Cross-check the three solvers on random small grids.

    Args:
        trials: number of random grids to test.
        seed: seed for a private ``random.Random``, so failures are reproducible.
        max_n: largest side length drawn (must be >= 1); keep it small because
            ``cherry_pickup_brute`` is exponential in n.

    Raises:
        AssertionError: if any two solvers disagree on a grid.
    """
    rng = random.Random(seed)
    for trial in range(trials):
        n = rng.randint(1, max_n)
        g = [[rng.choice([0, 0, 1, 1, -1]) for _ in range(n)] for _ in range(n)]
        # Clear the corners so most trials exercise real paths rather than the
        # trivial 0 answer (all solvers return 0 for a thorned corner anyway).
        if g[0][0] == -1:
            g[0][0] = 0
        if g[n - 1][n - 1] == -1:
            g[n - 1][n - 1] = 0
        a = cherry_pickup_dp(g)
        b = cherry_pickup_iter(g)
        c = cherry_pickup_brute(g)
        assert a == b == c, f"trial {trial}: g={g} dp={a} iter={b} brute={c}"
    # Report the real trial count instead of a hard-coded literal that could drift.
    print(f"stress_test: {trials} trials — memoized DP, iterative step-major DP, brute path-pair enum all agree.")


if __name__ == "__main__":
    # Deterministic hand-checked grids first, then the randomized cross-check.
    for check in (edge_cases, stress_test):
        check()
    print("ALL TESTS PASSED")
