"""
p80 — Number of Ways to Wear Different Hats (LeetCode 1434, HARD).

KEY: there are at most 10 people, so the set of people already wearing a hat fits in a bitmask (2^10 = 1024 states); the (at most 40) hats drive the outer loop.

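dp[h][mask] = number of ways to use hats 1..h so that exactly the people in mask
are wearing a hat, where liked_by[h] lists the people who like hat h.
Base case: dp[0][0] = 1 (no hats used, nobody wearing one).
Transition:
    dp[h][mask] = dp[h-1][mask]                              # skip hat h
                + sum of dp[h-1][mask ^ (1<<p)] over p in liked_by[h] with mask & (1<<p) != 0
Answer: dp[40][(1<<n) - 1] modulo 10^9 + 7.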
"""
from __future__ import annotations
import random
from functools import lru_cache
from typing import List

# Answers are reported modulo this prime (10**9 + 7).
MOD = 1_000_000_007
# Hat labels in the problem run from 1 up to and including this value.
MAX_HATS = 40


def number_ways_dp(hats: List[List[int]]) -> int:
    n = len(hats)
    full = (1 << n) - 1
    liked_by: List[List[int]] = [[] for _ in range(MAX_HATS + 1)]
    for p, hs in enumerate(hats):
        for h in hs:
            liked_by[h].append(p)
    # 1D rolling: dp[mask]. For each hat, build a new array (safest).
    dp = [0] * (full + 1)
    dp[0] = 1
    for h in range(1, MAX_HATS + 1):
        new_dp = dp[:]  # skip hat h
        for p in liked_by[h]:
            bit = 1 << p
            for mask in range(full + 1):
                if mask & bit:
                    new_dp[mask] = (new_dp[mask] + dp[mask ^ bit]) % MOD
        dp = new_dp
    return dp[full]


def number_ways_memo(hats: List[List[int]]) -> int:
    """Count hat assignments with a top-down memoized recursion over hats.

    ``solve(i, mask)`` counts the ways to complete the assignment using the
    ``i``-th liked hat onward, given that the people in ``mask`` already wear
    a hat.  Only hats at least one person likes are visited; any others can
    only be skipped and would not change the count.

    Args:
        hats: ``hats[p]`` lists the hat labels person ``p`` likes.  Labels are
            only used as dict keys, so any integer label is counted correctly.

    Returns:
        The number of valid assignments modulo ``MOD``; 1 for an empty
        ``hats`` list.

    Note:
        Recursion depth is one more than the number of distinct liked labels
        (at most MAX_HATS + 1 for labels within 1..MAX_HATS).
    """
    n = len(hats)
    full = (1 << n) - 1
    # Group people by hat label; a dict avoids the fixed-size table that
    # dropped label MAX_HATS + 1, raised on larger labels and mis-bucketed
    # negative ones.
    liked_by: dict[int, List[int]] = {}
    for p, hs in enumerate(hats):
        for h in hs:
            liked_by.setdefault(h, []).append(p)
    groups = list(liked_by.values())  # fixed visiting order for the liked hats
    num_groups = len(groups)

    @lru_cache(maxsize=None)
    def solve(i: int, mask: int) -> int:
        if mask == full:
            return 1  # everyone has a hat; remaining hats stay unused
        if i == num_groups:
            return 0  # hats exhausted with someone still bare-headed
        total = solve(i + 1, mask)  # nobody takes hat i
        for p in groups[i]:
            bit = 1 << p
            if not (mask & bit):
                total = (total + solve(i + 1, mask | bit)) % MOD
        return total

    return solve(0, 0)


def number_ways_brute(hats: List[List[int]]) -> int:
    """Count hat assignments by exhaustive backtracking (reference oracle).

    People are seated one at a time in index order; each tries every hat in
    their own list that no earlier person is already wearing.  The work grows
    with the product of the people's list lengths, so this is only meant for
    the small inputs used by the tests.

    Returns:
        The number of complete assignments modulo ``MOD`` (1 for an empty
        ``hats`` list).
    """
    worn: set = set()
    people = len(hats)

    def ways_from(person: int) -> int:
        # Everyone seated: this branch is exactly one complete assignment.
        if person == people:
            return 1
        subtotal = 0
        for hat in hats[person]:
            if hat in worn:
                continue
            worn.add(hat)
            subtotal += ways_from(person + 1)
            worn.discard(hat)  # backtrack so sibling branches may use it
        return subtotal % MOD

    return ways_from(0)


def edge_cases() -> None:
    """Check every solver against a table of known answers.

    The first four inputs are the known-answer samples; the last two pin the
    degenerate inputs: no people (exactly one, empty, assignment) and two
    people competing for a single hat (no assignment).  Each solver is run on
    every case so that all three implementations are held to the same answers.

    Raises:
        AssertionError: naming the solver and input on the first mismatch.
    """
    cases = [
        ([[3, 4], [4, 5], [5]], 1),
        ([[3, 5, 1], [3, 5]], 4),
        ([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]], 24),
        ([[1, 2, 3], [2, 3, 5, 6], [1, 3, 7, 9], [1, 8, 9], [2, 5, 7]], 111),
        ([], 1),
        ([[1], [1]], 0),
    ]
    for solver in (number_ways_dp, number_ways_memo, number_ways_brute):
        for hats, expected in cases:
            got = solver(hats)
            assert got == expected, (
                f"{solver.__name__}({hats}) = {got}, expected {expected}"
            )
    print("edge_cases: PASS")


def stress_test() -> None:
    """Cross-check the DP, memoized and brute-force solvers on random inputs.

    Each trial draws 1-5 people whose liked hats come from a window of 1-6
    consecutive labels.  The window sits at the bottom of the label range
    (starting at hat 1), at the top (ending at hat MAX_HATS), or at a random
    offset in between, so both ends of the hat loop are exercised.  A fixed
    seed keeps any failure reproducible.

    Raises:
        AssertionError: if the solvers disagree on any trial.
    """
    trials = 150
    rng = random.Random(42)
    for trial in range(trials):
        n = rng.randint(1, 5)
        width = rng.randint(1, 6)
        top = MAX_HATS - width  # largest base that keeps labels <= MAX_HATS
        base = rng.choice((0, top, rng.randint(0, top)))
        labels = range(base + 1, base + width + 1)
        hats = [rng.sample(labels, rng.randint(1, width)) for _ in range(n)]
        a = number_ways_dp(hats)
        b = number_ways_memo(hats)
        c = number_ways_brute(hats)
        assert a == b == c, f"trial {trial}: hats={hats} dp={a} memo={b} brute={c}"
    print(f"stress_test: {trials} trials — iterative DP, memoized, brute assignment all agree.")


if __name__ == "__main__":
    # Deterministic known-answer checks first, then the randomized cross-check.
    for run_suite in (edge_cases, stress_test):
        run_suite()
    print("ALL TESTS PASSED")
