"""
p40 — N-Queens (LeetCode 51, HARD) + LC 52 count variant.

Implementations:
    solve_n_queens           — three-set backtracking; returns boards.
    total_n_queens           — LC 52 count via three-set backtracking (no boards).
    solve_n_queens_bitmask   — bitmask count (Knuth-style; fastest).
    solve_n_queens_brute     — exhaustive C(N²,N) for tiny N (oracle for stress).

Run: python3 solution.py
"""
from __future__ import annotations
from itertools import combinations
from typing import List

# Number of distinct N-Queens solutions, indexed by n (n = 0..10).
# Source: OEIS A000170 (https://oeis.org/A000170).
KNOWN_COUNTS = [
    1, 1, 0, 0, 2,      # n = 0..4
    10, 4, 40, 92,      # n = 5..8
    352, 724,           # n = 9..10
]


# --------------------------------------------------------------------------------------
# LC 51: return all distinct N-Queens boards.
# Three-set invariants: cols, diag = r+c, anti_diag = r-c.
# Row-by-row recursion (one queen per row by construction).
# --------------------------------------------------------------------------------------
def solve_n_queens(n: int) -> List[List[str]]:
    """Return every placement of ``n`` mutually non-attacking queens (LC 51).

    Each board is a list of ``n`` strings in which row ``r`` has ``'Q'`` at its
    queen's column and ``'.'`` elsewhere. Rows are filled top to bottom and
    columns are tried left to right, so boards come out ordered
    lexicographically by their column sequence. Returns ``[]`` when no
    placement exists (e.g. n = 2 or 3).
    """
    boards: List[List[str]] = []
    placement: List[int] = [0] * n   # placement[row] = column of that row's queen
    used_cols: set = set()
    used_sums: set = set()    # row + col of every placed queen (one diagonal family)
    used_diffs: set = set()   # row - col of every placed queen (the other family)

    def render() -> List[str]:
        # Strings are built only for complete solutions, never for partial states.
        return ['.' * col + 'Q' + '.' * (n - 1 - col) for col in placement]

    def place(row: int) -> None:
        if row == n:
            boards.append(render())
            return
        for col in range(n):
            s, d = row + col, row - col
            if col in used_cols or s in used_sums or d in used_diffs:
                continue
            # (row, col) is not attacked by any queen already in rows 0..row-1.
            placement[row] = col
            used_cols.add(col)
            used_sums.add(s)
            used_diffs.add(d)
            place(row + 1)
            used_cols.remove(col)
            used_sums.remove(s)
            used_diffs.remove(d)

    place(0)
    return boards


# --------------------------------------------------------------------------------------
# LC 52: count solutions only. Same three-set search as LC 51, but each recursive call
# returns the number of completions below it instead of building boards.
# --------------------------------------------------------------------------------------
def total_n_queens(n: int) -> int:
    """Return how many ways ``n`` non-attacking queens fit on an n x n board.

    Uses the same row-by-row search as ``solve_n_queens`` without materialising
    boards. ``n == 0`` counts the empty board (returns 1); negative ``n``
    returns 0.
    """
    busy_cols: set = set()
    busy_sums: set = set()    # row + col of every placed queen
    busy_diffs: set = set()   # row - col of every placed queen

    def completions(row: int) -> int:
        if row == n:
            return 1          # every row already holds a queen
        subtotal = 0
        for col in range(n):
            blocked = (
                col in busy_cols
                or (row + col) in busy_sums
                or (row - col) in busy_diffs
            )
            if blocked:
                continue
            busy_cols.add(col)
            busy_sums.add(row + col)
            busy_diffs.add(row - col)
            subtotal += completions(row + 1)
            busy_cols.remove(col)
            busy_sums.remove(row + col)
            busy_diffs.remove(row - col)
        return subtotal

    return completions(0)


# --------------------------------------------------------------------------------------
# Bitmask version (Knuth). Fastest classical N-Queens.
# Encode each attack set as an int in which bit c stands for column c of the current row.
# `free & -free` isolates the lowest set bit, i.e. the lowest-index free column.
# Shift diag left and anti right once per row so each diagonal attack lands on the
# column it threatens in the next row; the left shift is masked to keep ints n bits wide.
# --------------------------------------------------------------------------------------
def solve_n_queens_bitmask(n: int) -> int:
    """Count N-Queens solutions for an ``n`` x ``n`` board using bitmasks (LC 52).

    Returns 1 for ``n == 0`` (the empty board) and 0 for negative ``n`` -- the
    same values ``total_n_queens`` gives -- instead of raising ``ValueError``
    from a negative shift count.
    """
    if n < 0:
        return 0
    full = (1 << n) - 1   # low n bits set: every column of a row (0 when n == 0)

    def rec(cols: int, diag: int, anti: int) -> int:
        # cols == full means every column -- hence every row -- holds a queen.
        if cols == full:
            return 1
        total = 0
        free = full & ~(cols | diag | anti)   # columns not attacked in this row
        while free:
            bit = free & -free      # lowest free column
            free ^= bit             # remove it from the candidates
            # INVARIANT: diag << 1 / anti >> 1 propagates diagonal attacks to the next
            # row. Masking the left shift drops bits >= n (never valid columns), so the
            # masks stay bounded instead of growing one bit per level.
            total += rec(cols | bit, ((diag | bit) << 1) & full, (anti | bit) >> 1)
        return total

    return rec(0, 0, 0)


# --------------------------------------------------------------------------------------
# Brute oracle: try every C(n², n) placement, check pairwise non-attack.
# Use for n ≤ 5 only.
# --------------------------------------------------------------------------------------
def solve_n_queens_brute(n: int) -> int:
    """Count N-Queens solutions by testing every n-cell subset of the board.

    Exhaustive over all C(n^2, n) subsets, so it is independent of the
    row-by-row solvers and serves as an oracle; only practical for n <= 5.
    A negative ``n`` raises ``ValueError`` (from ``itertools.combinations``).
    """
    if n == 0:
        return 1
    board_cells = [(row, col) for row in range(n) for col in range(n)]

    def attacks(a, b) -> bool:
        # Same row, same column, or either diagonal.
        (r1, c1), (r2, c2) = a, b
        return r1 == r2 or c1 == c2 or r1 + c1 == r2 + c2 or r1 - c1 == r2 - c2

    return sum(
        1
        for chosen in combinations(board_cells, n)
        if not any(attacks(p, q) for p, q in combinations(chosen, 2))
    )


# --------------------------------------------------------------------------------------
# Tests
# --------------------------------------------------------------------------------------
def _verify_board(board: List[str], n: int) -> bool:
    """Sanity-check a single solution board."""
    if len(board) != n:
        return False
    queens = []
    for r, row in enumerate(board):
        if len(row) != n or row.count('Q') != 1:
            return False
        c = row.index('Q')
        queens.append((r, c))
    for i in range(n):
        r1, c1 = queens[i]
        for j in range(i + 1, n):
            r2, c2 = queens[j]
            if r1 == r2 or c1 == c2 or (r1 + c1) == (r2 + c2) or (r1 - c1) == (r2 - c2):
                return False
    return True


def edge_cases() -> None:
    """Deterministic checks: tiny boards, known counts, and cross-solver agreement.

    Raises AssertionError on the first mismatch; prints a PASS line otherwise.
    """
    # n = 1
    s1 = solve_n_queens(1)
    assert s1 == [["Q"]], s1

    # n = 2, 3 -> no solutions
    assert solve_n_queens(2) == []
    assert solve_n_queens(3) == []

    # n = 4 -> 2 solutions
    s4 = solve_n_queens(4)
    assert len(s4) == 2
    for b in s4:
        assert _verify_board(b, 4)

    # Counts match the known sequence for n = 1..len(KNOWN_COUNTS)-1.
    # n = 0 is skipped on purpose: LC constrains n >= 1, so the empty board is out of spec.
    for n in range(1, len(KNOWN_COUNTS)):
        got = total_n_queens(n)
        assert got == KNOWN_COUNTS[n], f"n={n}: got {got}, want {KNOWN_COUNTS[n]}"

    # Bitmask agrees with set-based count for n in [1..9]
    for n in range(1, 10):
        a = total_n_queens(n)
        b = solve_n_queens_bitmask(n)
        assert a == b, f"n={n}: set={a}, bitmask={b}"

    # All solutions returned by solve_n_queens are valid
    for n in range(1, 7):
        for board in solve_n_queens(n):
            assert _verify_board(board, n), f"invalid board for n={n}: {board}"

    print("edge_cases: PASS")


def stress_test() -> None:
    """Cross-check every solver against each other and against KNOWN_COUNTS (n = 1..9).

    Raises AssertionError on the first disagreement.
    """
    # Small n: set-based, bitmask, brute oracle AND the OEIS table must all agree.
    # Brute cost grows as C(n², n), so keep n ≤ 5.
    for n in range(1, 6):
        a = total_n_queens(n)
        b = solve_n_queens_bitmask(n)
        c = solve_n_queens_brute(n)
        known = KNOWN_COUNTS[n]
        assert a == b == c == known, (
            f"DISAGREE n={n}: set={a}, bitmask={b}, brute={c}, known={known}"
        )

    # Larger n: brute is too slow; compare set-based and bitmask against OEIS.
    for n in range(6, 10):
        a = total_n_queens(n)
        b = solve_n_queens_bitmask(n)
        assert a == b == KNOWN_COUNTS[n], f"n={n}: set={a}, bitmask={b}, known={KNOWN_COUNTS[n]}"

    print("stress_test: counts match for n=1..9 across set-based, bitmask, brute (n≤5), and OEIS.")


if __name__ == "__main__":
    # Deterministic edge cases first, then the cross-solver stress comparison.
    for _suite in (edge_cases, stress_test):
        _suite()
    print("ALL TESTS PASSED")
