"""
Verification of Module 3 dimensional cascading-depletion results.
Renewal Libertarianism, Part 1, Chapter III, Module 3.

The Cascading Depletion Theorem and the Uneven Cascade Pattern proposition
formalize how depletion of any one legitimacy resource accelerates depletion
of others, and how that dynamic produces dimension-specific decline rates.

This script verifies the two formalized results:

  (a) Theorem [Cascading Depletion Across Dimensions]:
      For interior optima, partial c_k^* / partial L_j < 0 for all k != j.
      Depletion of any one resource accelerates depletion of others, and
      the detection capacity vector D(L) is monotonically increasing in
      each L_j -- so as resources deplete, detection falls in every
      dimension those resources support.

  (b) Prop [Uneven Cascade Pattern]:
      The rate of detection capacity decline in dimension d depends on
      the depletion rates of resources whose primary contribution is to
      dimension d. Different cascade patterns across resources produce
      different decline rates across dimensions.

The verification uses a tractable specification of substrate's per-period
capture-rate optimization across J resources. The cross-resource coupling
packages the manuscript's three channels (portfolio re-optimization,
refresh-cost, continuation-value), each of which contributes positively to
the manuscript's partial F_k / partial L_j > 0 and thus to the negative
partial c_k^* / partial L_j claim.

Run: uv run python cascading_depletion.py

Notes on what the verification does and does not show:

  1. The verification uses a closed-form specification of substrate's
     capture-rate optimum that exhibits the manuscript's three channels
     through one composite coupling alpha_{kj} / L_j (so depleting L_j
     raises the marginal benefit of capturing k). This is one specification
     consistent with the manuscript's signed channels; other specifications
     with the same channel signs should give the same qualitative outcomes,
     but only this specification is checked here.

  2. Test (a) sweeps the central calibration (the common L level, the
     alpha off-diagonal magnitude, and the resource count J) to confirm the
     sign claim is not a calibration artefact. Test (b) uses a stylized
     two-resources-per-dimension construction to make the uneven-cascade
     prediction visible cleanly.

  3. The detection-capacity claim D(L) monotonically increasing in each
     L_j follows from the linear D = W^T L specification with W >= 0; this
     matches the manuscript's resource-to-dimension mapping.
"""

import numpy as np

from _helpers import banner


# ----------------------------------------------------------------------
# Model: substrate's per-period capture-rate optimization across J resources
# ----------------------------------------------------------------------
# Per-period payoff for capturing resource j at rate c_j:
#
#     g_j(c_j; L) = c_j * h(L_j) * (1 + sum_{k != j} alpha_{jk} / L_k)
#
# with quadratic cost (1/2) * kappa * c_j^2.
#
# h(L_j) is the resource-specific marginal benefit (decreasing in L_j: the
# more depleted resource j already is, the larger the marginal benefit of
# capturing it further). The cross
# term sum_{k != j} alpha_{jk} / L_k packages the manuscript's three channels:
#   - portfolio re-optimization: substrate's marginal benefit of capturing j
#     rises when other resources are depleted (so other resources offer less
#     restraint)
#   - refresh-cost: refresh efficiency for j falls when coordination-
#     supporting resources k != j are depleted
#   - continuation-value: marginal cost of further capture is lower when
#     complementary resources are depleted
#
# All three channels make the marginal payoff partial g_j / partial c_j
# decreasing in L_k for k != j, i.e. partial^2 g_j / (partial c_j partial L_k)
# < 0 (the marginal benefit rises as L_k falls). Substrate's interior FOC:
#
#     dg_j / dc_j = kappa * c_j
# =>  c_j^*(L) = (1 / kappa) * h(L_j) * (1 + sum_{k != j} alpha_{jk} / L_k)
#
# Then partial c_k^* / partial L_j for j != k equals
#     -(1 / kappa) * h(L_k) * alpha_{kj} / L_j^2  <  0  for alpha_{kj} > 0,
# verifying the theorem's sign claim.
# ----------------------------------------------------------------------


def h(L):
    """Per-resource marginal-benefit weight, h(L) = 1 / (0.1 + L).

    Strictly decreasing in L (for L > -0.1): the more depleted a resource
    stock is, the larger the marginal benefit of capturing it further. Only
    arithmetic operators are used, so numpy arrays are handled elementwise.
    """
    offset = 0.1
    return 1.0 / (L + offset)


def make_alpha(J, alpha_off=0.1, rng=None):
    """Build the J x J cross-resource coupling matrix alpha.

    With ``rng=None`` every off-diagonal entry equals ``alpha_off``.
    Otherwise all entries are drawn from ``rng.uniform`` on
    [0.5 * alpha_off, 1.5 * alpha_off). In both cases the diagonal is set
    to zero, because the coupling sum in c_j^* runs over k != j only.
    """
    shape = (J, J)
    if rng is not None:
        coupling = rng.uniform(0.5 * alpha_off, 1.5 * alpha_off, shape)
    else:
        coupling = np.full(shape, alpha_off)
    coupling[np.diag_indices(J)] = 0.0
    return coupling


def optimal_capture(L, alpha, kappa=1.0):
    """Closed-form interior optimum of the per-resource capture FOC.

    Returns a float array whose j-th entry is

        c_j^* = (1 / kappa) * h(L[j]) * (1 + sum_{k != j} alpha[j, k] / L[k]).

    The k == j term is skipped explicitly, so the diagonal of ``alpha``
    never contributes, whatever its value.
    """
    inv_kappa = 1.0 / kappa
    indices = range(len(L))

    def coupling(j):
        # Cross-resource term summed in increasing k order.
        return sum(alpha[j, k] / L[k] for k in indices if k != j)

    return np.array(
        [inv_kappa * h(L[j]) * (1.0 + coupling(j)) for j in indices],
        dtype=float,
    )


def numerical_partials(L, alpha, kappa=1.0, h_pert=1e-5):
    """Return partials[j, k] = partial c_k^* / partial L_j by finite difference.

    Parameters
    ----------
    L : array_like
        Resource stocks, one per resource. Converted to a float array first.
        Otherwise an integer array would silently truncate ``L[j] + h_pert``
        back to the original integer on assignment, giving all-zero partials.
    alpha : ndarray
        Cross-resource coupling matrix, passed through to ``optimal_capture``.
    kappa : float
        Quadratic capture-cost coefficient, passed through to ``optimal_capture``.
    h_pert : float
        Step added to one L_j at a time.

    Returns
    -------
    ndarray, shape (J, J)
        Forward-difference estimate of the Jacobian; row j holds the response
        of every c_k^* to a perturbation of L_j.
    """
    L = np.asarray(L, dtype=float)
    c0 = optimal_capture(L, alpha, kappa)
    J = len(L)
    partials = np.zeros((J, J))
    for j in range(J):
        # Forward (not central) difference: only stocks >= the given L are
        # evaluated, so the step never crosses the singularity at L_k = 0.
        L_pert = L.copy()
        L_pert[j] += h_pert
        cP = optimal_capture(L_pert, alpha, kappa)
        partials[j, :] = (cP - c0) / h_pert
    return partials


# ----------------------------------------------------------------------
# (a) Theorem [Cascading Depletion Across Dimensions]
# ----------------------------------------------------------------------

def test_cascading_depletion_signs():
    """Check partial c_k^* / partial L_j < 0 for every off-diagonal (j, k).

    Uses the central calibration (J = 6, every L_j = 0.7, uniform alpha)
    and prints the finite-difference partials matrix. Raises AssertionError
    if any off-diagonal partial is not strictly negative, including NaN.
    """
    banner("(a) Theorem [Cascading Depletion]: partial c_k^* / partial L_j < 0")

    print("\n  Computing partial c_k^* / partial L_j numerically by finite")
    print("  differences. Manuscript claim: the partial is strictly negative")
    print("  for every (j, k) pair with k != j.")

    J = 6
    L0 = np.full(J, 0.7)
    alpha = make_alpha(J)
    partials = numerical_partials(L0, alpha)

    # A pair is a violation unless its partial is strictly negative. Writing
    # the test as "not < 0" rather than ">= 0" makes a NaN partial (e.g. from
    # an inf - inf difference) fail the check instead of passing silently.
    violations = [
        (j, k, partials[j, k])
        for j in range(J) for k in range(J)
        if j != k and not partials[j, k] < 0.0
    ]

    print(f"\n  Central calibration: J = {J} resources at L_j = 0.7 for all j.")
    print("  Off-diagonal partials matrix [d c_k / d L_j], rows = j, cols = k:")
    print()
    for j in range(J):
        cells = (
            "    --    " if j == k else f"  {partials[j, k]:+8.5f}"
            for k in range(J)
        )
        print("      " + "".join(cells))

    assert not violations, (
        f"Sign claim violated at {len(violations)} (j, k) pairs: "
        f"{violations[:3]}"
    )

    print()
    print("  [pass] partial c_k^* / partial L_j < 0 for every (j, k) with k != j")


def test_cascading_depletion_sweep():
    """Check the off-diagonal sign claim across a grid of calibrations.

    Sweeps the resource count J, a common stock level L_central, and the
    alpha off-diagonal magnitude. Alpha entries are drawn from one seeded
    generator. Raises AssertionError if any off-diagonal partial is not
    strictly negative (NaN counts as a violation).
    """
    banner("(a, robustness) Sign claim holds across parameter sweeps")

    print("\n  Sweeping the central L vector, the alpha off-diagonal magnitude,")
    print("  and the resource count J. The sign claim partial c_k^* /")
    print("  partial L_j < 0 should hold at every point.")

    # Sweep grids. The calibration count printed below is derived from these,
    # so editing a grid cannot leave the summary line stale.
    J_grid = [3, 5, 8, 12]
    L_grid = [0.4, 0.6, 0.7, 0.85]
    alpha_grid = [0.02, 0.05, 0.10, 0.20]

    total_pairs = 0
    violations = 0
    # One shared generator, so the alpha draws depend on the loop order below.
    rng = np.random.default_rng(seed=20260528)

    for J in J_grid:
        off_diagonal = ~np.eye(J, dtype=bool)
        for L_central in L_grid:
            for alpha_mag in alpha_grid:
                L0 = np.full(J, L_central)
                alpha = make_alpha(J, alpha_off=alpha_mag, rng=rng)
                partials = numerical_partials(L0, alpha)
                off_vals = partials[off_diagonal]
                total_pairs += off_vals.size
                # "not < 0" so that NaN partials are counted as violations.
                violations += int(np.count_nonzero(~(off_vals < 0.0)))

    n_calibrations = len(J_grid) * len(L_grid) * len(alpha_grid)
    print(f"\n  Swept {total_pairs} (j, k) off-diagonal pairs across "
          f"{len(J_grid)} x {len(L_grid)} x {len(alpha_grid)} = "
          f"{n_calibrations} calibrations.")
    print(f"  Sign violations: {violations}")
    assert violations == 0, (
        f"Sign claim violated at {violations} / {total_pairs} pairs"
    )

    print()
    print("  [pass] sign claim robust across L_central, alpha magnitude, and J")


def test_detection_monotonicity():
    """Check that D(L) = W^T L falls in every dimension a depleted resource supports.

    Draws a positive random weight matrix W (J = 6 resources, 3 dimensions)
    and lowers one L_j at a time from 0.7 to 0.30. Requires a strict fall
    in D_d wherever W[j, d] > 0, and no rise elsewhere.
    """
    banner("(a, corollary) Detection capacity D(L) monotonically increasing in L")

    print("\n  Corollary of the theorem: with D(L) = W^T L and W >= 0, each")
    print("  component D_d is non-decreasing in every L_j with w_{jd} > 0.")
    print("  As resources deplete (L_j falls), detection falls in every")
    print("  dimension that resource supports.")

    J = 6
    rng = np.random.default_rng(seed=42)
    W = rng.uniform(0.1, 1.0, (J, 3))
    L0 = np.full(J, 0.7)
    D0 = W.T @ L0

    print(f"\n  Random weight matrix W (J = {J} resources, 3 dimensions):")
    for j in range(J):
        print(f"      W[{j}] = ({W[j, 0]:.3f}, {W[j, 1]:.3f}, {W[j, 2]:.3f})")
    print(f"\n  Baseline detection D0 = ({D0[0]:.4f}, {D0[1]:.4f}, {D0[2]:.4f})")
    print()

    for j in range(J):
        L_lower = L0.copy()
        L_lower[j] = 0.30
        D_lower = W.T @ L_lower
        dD = D_lower - D0
        print(f"      L_{j} -> 0.30:  dD = ({dD[0]:+.4f}, {dD[1]:+.4f}, "
              f"{dD[2]:+.4f})")
        for d in range(3):
            if W[j, d] > 0.0:
                # Resource j supports dimension d, so detection must fall.
                assert dD[d] < 0.0, (
                    f"Detection did not fall in dim {d} when L_{j} fell"
                )
            else:
                # Resource j does not load on d, so detection must not rise.
                assert dD[d] <= 1e-12, (
                    f"Detection rose in dim {d} when L_{j} fell"
                )

    print()
    print("  [pass] D(L) monotonically increasing in each L_j (detection falls "
          "as L_j falls)")


# ----------------------------------------------------------------------
# (b) Prop [Uneven Cascade Pattern]
# ----------------------------------------------------------------------

def test_uneven_cascade():
    """Check that the fastest-depleting dimension shows the largest detection decline.

    Six resources, two loading on each of three dimensions. In each
    scenario the first two resources of ``order`` drop to L_fast, the next
    two to L_slow, and the last two stay at baseline. For each scenario the
    test asserts three things:
      * both fast-depleting resources load on the same dimension;
      * that dimension's decline is strictly larger than every other
        dimension's;
      * the scenarios produce pairwise-distinct fastest-declining dimensions.
    """
    banner("(b) Prop [Uneven Cascade Pattern]: dimension-specific decline rates")

    print("\n  Two resources support each dimension primarily. When a polity's")
    print("  cascade depletes the resources supporting dimension d faster than")
    print("  others, dimension d's detection capacity declines fastest.")

    # Two resources primarily support each of three dimensions
    W = np.array([
        [1.0, 0.0, 0.0],  # resource 0: extraction-primary
        [1.0, 0.0, 0.0],  # resource 1: extraction-primary
        [0.0, 1.0, 0.0],  # resource 2: scope-primary
        [0.0, 1.0, 0.0],  # resource 3: scope-primary
        [0.0, 0.0, 1.0],  # resource 4: quality-primary
        [0.0, 0.0, 1.0],  # resource 5: quality-primary
    ])
    L0 = np.full(6, 0.7)
    D0 = W.T @ L0

    print(f"\n  Baseline: L = ({L0[0]:.1f},)*6,  D = ({D0[0]:.2f}, {D0[1]:.2f}, "
          f"{D0[2]:.2f})")
    print()

    scenarios = {
        "Late-Republican Rome (scope-primary deplete first)":   [2, 3, 0, 1, 4, 5],
        "Late-Han China (quality-primary deplete first)":       [4, 5, 0, 1, 2, 3],
        "Extraction-primary cascade (hypothetical)":            [0, 1, 2, 3, 4, 5],
    }

    L_fast = 0.30  # depleted-fast stock level
    L_slow = 0.55  # depleted-slow stock level

    print("  Three depletion patterns; first two resources of each pattern hit")
    print(f"  the fast stock level {L_fast}, next two hit the slow stock level "
          f"{L_slow},")
    print(f"  last two stay at baseline {L0[0]}.")
    print()

    fastest_dims = {}
    for label, order in scenarios.items():
        L = L0.copy()
        L[order[0:2]] = L_fast  # first two of the order deplete fastest
        L[order[2:4]] = L_slow  # next two deplete slowly; last two stay
        D = W.T @ L
        dD = D - D0
        print(f"  {label}")
        print(f"      L      = {np.array2string(L, precision=2)}")
        print(f"      dD     = ({dD[0]:+.3f}, {dD[1]:+.3f}, {dD[2]:+.3f})")

        # The dimension whose primary resources depleted fastest should show
        # the largest decline. Both fast resources must share that dimension
        # for the scenario to test what it claims to.
        fastest_dim = int(np.argmax(W[order[0]]))
        assert int(np.argmax(W[order[1]])) == fastest_dim, (
            f"Fast-depleting resources in scenario '{label}' do not share a "
            f"primary dimension"
        )
        decline = -dD
        for d in range(3):
            if d != fastest_dim:
                # Strict: a tie must not count as "declining fastest".
                assert decline[fastest_dim] > decline[d], (
                    f"Fastest-depleted dim {fastest_dim} not declining fastest "
                    f"in scenario '{label}'"
                )
        fastest_dims[label] = fastest_dim
        print(f"      Fastest decline in dim {fastest_dim} "
              f"(magnitude {decline[fastest_dim]:.3f}), as predicted.")
        print()

    # The pass message below claims the patterns yield *different*
    # fastest-declining dimensions; verify that rather than assume it.
    assert len(set(fastest_dims.values())) == len(scenarios), (
        f"Scenarios do not produce distinct fastest-declining dimensions: "
        f"{fastest_dims}"
    )

    print("  [pass] Different depletion patterns produce different fastest-")
    print("         declining dimensions, matching the resource-to-dimension")
    print("         mapping in each scenario")


# ----------------------------------------------------------------------

def main():
    """Print the header, run every verification in order, then print a summary.

    Each check raises AssertionError on failure, so the closing summary is
    printed only when all of them pass.
    """
    intro_lines = (
        "",
        "  Verifies the two formalized results in Module 3:",
        "      (a) Theorem [Cascading Depletion Across Dimensions]",
        "          (with the detection-capacity monotonicity corollary)",
        "      (b) Prop    [Uneven Cascade Pattern]",
        "",
        "  Under a tractable specification of substrate's capture-rate",
        "  optimization with the manuscript's three signed channels packaged",
        "  as one cross-resource coupling alpha_{jk} / L_k.",
    )
    checks = (
        test_cascading_depletion_signs,
        test_cascading_depletion_sweep,
        test_detection_monotonicity,
        test_uneven_cascade,
    )

    banner("Cascading Depletion - numerical verification of Ch. III Module 3")
    for text in intro_lines:
        print(text)

    for check in checks:
        check()

    banner("All verifications passed")
    print("""
  Result: Module 3's two formalized results hold on a tractable
  specification consistent with the manuscript's three signed channels.

    (a) Theorem [Cascading Depletion Across Dimensions]:
        partial c_k^* / partial L_j < 0 for every off-diagonal (j, k) at
        the central calibration and across a sweep of L_central, alpha
        magnitude, and resource count J.
        The detection-capacity corollary D(L) monotonically increasing in
        each L_j follows from the linear D = W^T L specification with
        W >= 0.

    (b) Prop [Uneven Cascade Pattern]:
        Different depletion patterns across resources produce different
        fastest-declining dimensions in the detection vector. The dimension
        whose primary supporting resources deplete fastest experiences the
        most rapid detection decline, as the manuscript predicts.

  The numerical magnitudes are calibration-dependent; the qualitative
  claims (sign of the partials, monotonicity of D in L, uneven decline
  across dimensions) are structural and confirmed.
""")


if __name__ == "__main__":
    # Run the verification suite only when executed as a script, so that
    # importing this module (e.g. to reuse the model helpers) does not run
    # the checks.
    main()
