"""
Audience Formation Sensitivity Check (v2)

Tests whether the framework's key qualitative comparative statics survive
under different audience-aggregation rules.

Three aggregation rules tested:
  1. Reduced form: r(tau) = Phi((tau - tau_c)/sqrt(v))  [current model]
  2. Bayesian-with-noise: each agent forms a posterior; revolt iff the median posterior mean > tau_c
  3. Threshold-cascade: heterogeneous thresholds with social contagion

Four predictions tested (each measured by the long-run stock, i.e. the mean of
the last 50 steps of the simulated stock trajectory):
  - Diversity helps: higher long-run stock as J increases
  - Detection helps: higher long-run stock as opacity v decreases
  - Refresh helps: higher equilibrium stock as mu increases
  - Rotation helps: higher equilibrium stock as eta increases
"""

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from matplotlib.patches import FancyBboxPatch

# Shared typography for every figure produced by this script.
plt.rcParams.update({
    'font.family': 'serif',
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'mathtext.fontset': 'cm',
})

import os
# Figures go into a "figures" directory next to this script; it is created on
# first run and left alone if it already exists.
_script_dir = os.path.dirname(os.path.abspath(__file__))
OUTPUT_DIR = os.path.join(_script_dir, 'figures')
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ============================================================================
# Aggregation rules
# ============================================================================

def revolt_prob_reduced_form(tau, tau_c, v):
    """Original reduced form: r(tau) = Phi((tau - tau_c)/sqrt(v)).

    Closed-form revolt probability: the standard normal CDF of the extraction
    excess ``tau - tau_c`` measured in units of the noise standard deviation
    ``sqrt(v)``. The variance is floored at 0.01 so the z-score stays finite
    as ``v`` approaches zero.

    Parameters
    ----------
    tau : float or array_like
        Attempted extraction.
    tau_c : float or array_like
        Critical extraction level; the revolt probability is exactly 1/2 when
        ``tau == tau_c``.
    v : float or array_like
        Opacity (noise variance). Scalars and arrays are both accepted and
        broadcast against ``tau`` and ``tau_c``.

    Returns
    -------
    float or numpy.ndarray
        Revolt probability in [0, 1].
    """
    # np.maximum (unlike builtin max) works element-wise, so array-valued v
    # broadcasts the same way tau and tau_c already do.
    z = (tau - tau_c) / np.sqrt(np.maximum(v, 0.01))
    return stats.norm.cdf(z)


def revolt_prob_bayesian(tau, tau_c, v, N=100, prior_var=10.0, n_trials=100):
    """
    Bayesian-with-noise aggregation.

    Each of ``N`` agents receives a private signal ``s_i = tau + eta_i`` with
    ``eta_i ~ N(0, v)`` and forms the posterior mean

        m_i = (prior_var * s_i + v * tau_c) / (prior_var + v),

    a precision-weighted average of its signal and a prior centred on
    ``tau_c``. The population revolts iff the median posterior mean exceeds
    ``tau_c``. The revolt probability is the fraction of ``n_trials``
    independent Monte Carlo draws in which that happens.

    The variance is floored at 0.01 (as in the other rules). The same floored
    value is used for the signal noise and for the posterior weights, so the
    weights always match the noise actually drawn and stay positive.

    Note: for prior_var > 0 the map s -> m is affine and increasing, and
    m > tau_c exactly when s > tau_c. The decision is therefore equivalent to
    "median signal > tau_c": prior_var changes the posterior means but not the
    revolt outcome.

    Parameters
    ----------
    tau : float
        True extraction level.
    tau_c : float
        Critical extraction level (also the prior mean).
    v : float
        Signal-noise variance (opacity).
    N : int
        Number of agents per trial.
    prior_var : float
        Prior variance around ``tau_c``.
    n_trials : int
        Number of Monte Carlo trials; must be at least 1.

    Returns
    -------
    float
        Estimated revolt probability in [0, 1].

    Raises
    ------
    ValueError
        If ``n_trials < 1``.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")
    var = max(v, 0.01)
    # One (n_trials, N) draw consumes the global RNG stream in the same order
    # as n_trials sequential draws of size N, so seeded results are unchanged.
    signals = tau + np.random.normal(0, np.sqrt(var), (n_trials, N))
    posterior_means = (prior_var * signals + var * tau_c) / (prior_var + var)
    median_posteriors = np.median(posterior_means, axis=1)
    revolts = int(np.count_nonzero(median_posteriors > tau_c))
    return revolts / n_trials


def revolt_prob_threshold_cascade(tau, tau_c, v, N=100, threshold_spread=2.0,
                                    cascade_threshold=0.3, social_resistance=1.5,
                                    n_trials=100):
    """
    Threshold-cascade aggregation (single tipping step).

    Per Monte Carlo trial:
    - Each of ``N`` agents draws a private threshold
      ``tau_i ~ N(tau_c, threshold_spread**2)`` (``threshold_spread`` is a
      standard deviation) and a noisy signal ``s_i = tau + eta_i`` with
      ``eta_i ~ N(0, v)``, where v is floored at 0.01.
    - Initial revolters are agents with ``s_i > tau_i + social_resistance``,
      i.e. those willing to revolt alone despite the social barrier.
    - If the initial revolter fraction reaches ``cascade_threshold``, the
      barrier drops to zero and every agent with ``s_i > tau_i`` revolts.
      Otherwise only the initial revolters remain.
    - The trial counts as a revolt when the final revolter fraction > 0.5.

    Note: this is a one-shot tipping rule, not an iterated contagion. Each
    agent's choice depends only on its own signal and threshold, so once the
    barrier is removed the revolter set is already a fixed point. A model in
    which the barrier shrinks gradually with the current revolter fraction
    would be a different (and results-changing) specification.

    Parameters
    ----------
    tau : float
        True extraction level.
    tau_c : float
        Mean of the agents' individual thresholds.
    v : float
        Signal-noise variance (opacity).
    N : int
        Number of agents per trial.
    threshold_spread : float
        Standard deviation of the individual thresholds.
    cascade_threshold : float
        Initial revolter fraction needed to remove the social barrier.
    social_resistance : float
        Extra signal margin needed to revolt before the cascade triggers.
    n_trials : int
        Number of Monte Carlo trials; must be at least 1.

    Returns
    -------
    float
        Estimated revolt probability in [0, 1].

    Raises
    ------
    ValueError
        If ``n_trials < 1``.
    """
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")
    noise_sd = np.sqrt(max(v, 0.01))
    revolts = 0
    for _ in range(n_trials):
        # Draw order (thresholds, then signals) is kept per trial so seeded
        # runs reproduce the same random stream.
        thresholds = np.random.normal(tau_c, threshold_spread, N)
        signals = tau + np.random.normal(0, noise_sd, N)

        # Agents who revolt even against the social barrier.
        initial = signals > (thresholds + social_resistance)

        # Tipping: once critical mass is reached the barrier disappears.
        if initial.mean() >= cascade_threshold:
            final = signals > thresholds
        else:
            final = initial

        if final.mean() > 0.5:
            revolts += 1
    return revolts / n_trials


# ============================================================================
# Stock dynamics simulation
# ============================================================================

def simulate_stock(aggregation_fn, T=400, J=4, refresh_rate=0.020, rotation_rate=0.02,
                    capture_intensity=0.30, tau_c_init=5.0, v_base=8.0,
                    seed=42, **agg_kwargs):
    """
    Run the stock dynamics for ``T`` steps under one audience-aggregation rule.

    Each step works as follows:
    - The stock's depletion ``1 - L`` raises both the effective opacity
      (which diversity ``J`` lowers) and the substrate's attempted extraction.
    - ``aggregation_fn(tau_attempt, tau_c_init, v_eff, **agg_kwargs)`` gives
      the revolt probability, which scales capture by ``1 - revolt_prob``.
    - Capture is further scaled by ``capture_intensity / sqrt(J)``, the
      current stock and a rotation damping factor.
    - Refresh pulls the stock back toward 1.
    - The stock is clipped to [0.05, 1.0].

    The global NumPy RNG is reseeded with ``seed`` on entry, so stochastic
    aggregation rules see the same random stream for a given seed.

    Returns
    -------
    numpy.ndarray
        Stock trajectory of length ``T + 1``; element 0 is the initial stock 1.0.
    """
    np.random.seed(seed)

    # Rotation damps capture by 5x its rate, but never removes more than 80%.
    rotation_factor = max(1 - rotation_rate * 5, 0.2)

    stock = 1.0
    trajectory = [stock]
    for _step in range(T):
        depletion = 1 - stock
        v_eff = v_base / np.sqrt(J / 2) * (1 + depletion * 1.5)
        tau_attempt = tau_c_init * (1 + depletion * 1.0)

        revolt_prob = aggregation_fn(tau_attempt, tau_c_init, v_eff, **agg_kwargs)

        captured = (capture_intensity / np.sqrt(J)) * (1 - revolt_prob) * stock * rotation_factor
        candidate = stock + refresh_rate * depletion - captured
        stock = min(max(candidate, 0.05), 1.0)
        trajectory.append(stock)

    return np.array(trajectory)


def time_to_threshold(L_history, threshold=0.4):
    """Return the first time step at which the stock falls strictly below ``threshold``.

    Parameters
    ----------
    L_history : array_like of float
        Stock trajectory, e.g. as returned by ``simulate_stock`` (length T + 1,
        with the initial stock at index 0). Plain lists are accepted.
    threshold : float
        Stock level that defines terminal capture.

    Returns
    -------
    int
        Index of the first value ``< threshold``. If the stock never falls
        below it, returns ``len(L_history)``, which is T + 1 for a
        ``simulate_stock`` trajectory.
    """
    # asarray lets list input use the element-wise comparison below.
    below = np.asarray(L_history) < threshold
    if not below.any():
        return len(below)
    # argmax on a boolean array returns the index of the first True.
    return int(np.argmax(below))


# ============================================================================
# Run the four tests
# ============================================================================

# Number of final time steps averaged to estimate the long-run (equilibrium) stock.
EQUILIBRIUM_WINDOW = 50
# Monte Carlo trials per revolt-probability evaluation for the stochastic
# rules (Bayesian and threshold-cascade); the reduced form needs none.
SWEEP_N_TRIALS = 50


def _sweep_parameter(param_key, values, sim_kwarg, format_value):
    """Simulate the stock under all three aggregation rules for each swept value.

    Parameters
    ----------
    param_key : str
        Key under which ``values`` is stored in the returned dict (e.g. 'J').
    values : numpy.ndarray
        Parameter values to sweep.
    sim_kwarg : str
        Name of the ``simulate_stock`` keyword argument the values are passed as.
    format_value : callable
        Maps a value to the label printed before that value's results.

    Returns
    -------
    dict
        ``{param_key: values, 'reduced': [...], 'bayesian': [...],
        'cascade': [...]}`` with one long-run stock per value and rule.
    """
    results = {param_key: values, 'reduced': [], 'bayesian': [], 'cascade': []}
    for value in values:
        kwargs = {sim_kwarg: value}
        L_rf = simulate_stock(revolt_prob_reduced_form, **kwargs)
        L_bay = simulate_stock(revolt_prob_bayesian, n_trials=SWEEP_N_TRIALS, **kwargs)
        L_cas = simulate_stock(revolt_prob_threshold_cascade,
                               n_trials=SWEEP_N_TRIALS, **kwargs)

        # Long-run stock = mean over the final window of the trajectory.
        results['reduced'].append(L_rf[-EQUILIBRIUM_WINDOW:].mean())
        results['bayesian'].append(L_bay[-EQUILIBRIUM_WINDOW:].mean())
        results['cascade'].append(L_cas[-EQUILIBRIUM_WINDOW:].mean())
        print(f"  {format_value(value)}: RF={results['reduced'][-1]:.3f}, "
              f"Bay={results['bayesian'][-1]:.3f}, "
              f"Cas={results['cascade'][-1]:.3f}")
    return results


print("=" * 70)
print("TEST 1: Diversity helps (long-run stock increases with J)")
print("=" * 70)
print("Note: We measure equilibrium stock rather than time-to-threshold,")
print("because the alternative aggregations sustain stocks high enough")
print("that time-to-threshold is uninformative for them.")

J_range = np.array([2, 4, 6, 8, 10])
results_diversity = _sweep_parameter('J', J_range, 'J', lambda J: f"J={J}")


print("\n" + "=" * 70)
print("TEST 2: Detection helps (long-run stock decreases with opacity v)")
print("=" * 70)

v_range = np.array([4.0, 8.0, 16.0, 32.0])
results_opacity = _sweep_parameter('v', v_range, 'v_base', lambda v: f"v={v}")


print("\n" + "=" * 70)
print("TEST 3: Refresh helps (equilibrium stock increases with mu)")
print("=" * 70)

mu_range = np.array([0.005, 0.015, 0.025, 0.035, 0.05])
results_refresh = _sweep_parameter('mu', mu_range, 'refresh_rate',
                                   lambda mu: f"mu={mu:.3f}")


print("\n" + "=" * 70)
print("TEST 4: Rotation helps (equilibrium stock increases with eta)")
print("=" * 70)

eta_range = np.array([0.005, 0.02, 0.04, 0.08, 0.12])
results_rotation = _sweep_parameter('eta', eta_range, 'rotation_rate',
                                    lambda eta: f"eta={eta:.3f}")


# ============================================================================
# Compute slopes and verdicts
# ============================================================================

def compute_slope(xs, ys):
    """Return the secant slope between the first and last sweep points.

    Computes ``(ys[-1] - ys[0]) / (xs[-1] - xs[0])``. Interior points are
    ignored, so a non-monotone response can still produce the predicted sign.
    """
    x_arr = np.array(xs)
    y_arr = np.array(ys)
    rise = y_arr[-1] - y_arr[0]
    run = x_arr[-1] - x_arr[0]
    return rise / run

# For each comparative static: slope (via compute_slope) of the long-run stock
# against the swept parameter under each aggregation rule, plus the sign the
# framework predicts for it.
slopes = {
    test_name: {
        **{rule: compute_slope(sweep[x_key], sweep[rule])
           for rule in ('reduced', 'bayesian', 'cascade')},
        'predicted_sign': predicted,
        'metric': 'long-run stock',
    }
    for test_name, sweep, x_key, predicted in (
        ('diversity', results_diversity, 'J', '>'),
        ('opacity', results_opacity, 'v', '<'),
        ('refresh', results_refresh, 'mu', '>'),
        ('rotation', results_rotation, 'eta', '>'),
    )
}

def check_sign(value, sign, tol=0.001):
    """Return True when ``value`` has the predicted sign by more than ``tol``.

    Parameters
    ----------
    value : float
        Slope to test, in units of stock per unit of the swept parameter.
    sign : {'>', '<'}
        Predicted sign. '>' requires ``value > tol``; '<' requires
        ``value < -tol``.
    tol : float, optional
        Dead band around zero treated as "no effect" (default 0.001).

    Raises
    ------
    ValueError
        If ``sign`` is neither '>' nor '<'.

    NOTE(review): ``tol`` is in slope units. The same default is therefore far
    stricter for parameters swept over a wide range (e.g. v in [4, 32]) than
    for narrow ones (e.g. mu in [0.005, 0.05]). Consider scaling it by the
    swept range when comparing tests.
    """
    if sign == '>':
        return value > tol
    if sign == '<':
        return value < -tol
    raise ValueError(f"sign must be '>' or '<', got {sign!r}")

# Console report: per-rule slope and pass/fail for each comparative static,
# followed by whether all three aggregation rules agree.
print("\n" + "=" * 70)
print("VERDICT SUMMARY")
print("=" * 70)
for test, data in slopes.items():
    sign = data['predicted_sign']
    outcomes = [
        ('Reduced form:', data['reduced']),
        ('Bayesian:', data['bayesian']),
        ('Threshold-cascade:', data['cascade']),
    ]
    passed = [check_sign(slope, sign) for _, slope in outcomes]

    print(f"\n{test.upper()} (predicted: slope {sign} 0):")
    for (rule_label, slope), ok in zip(outcomes, passed):
        # Labels are left-aligned in a 19-character column.
        print(f"  {rule_label:<19}slope = {slope:.4f} ({'PASS' if ok else 'FAIL'})")
    verdict = 'YES - prediction robust' if all(passed) else 'NO - aggregation-dependent'
    print(f"  CONSENSUS: {verdict}")


# ============================================================================
# Generate figures
# ============================================================================

# Figure 1: Four-panel comparative statics -- one panel per swept parameter,
# one line per aggregation rule.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# (results key, marker/line format, colour, legend label) per aggregation rule.
_rule_styles = [
    ('reduced', 'o-', '#1f77b4', 'Reduced form'),
    ('bayesian', 's-', '#2ca02c', 'Bayesian-with-noise'),
    ('cascade', '^-', '#d62728', 'Threshold-cascade'),
]
# (axis, results dict, x key, x label, title, legend location, y-limits or None)
_panels = [
    (axes[0, 0], results_diversity, 'J', 'Resource diversity $J$',
     'Test 1: Diversity\nFramework predicts: $\\partial L / \\partial J > 0$',
     'lower right', (0, 1.05)),
    (axes[0, 1], results_opacity, 'v', 'Opacity $v$',
     'Test 2: Detection\nFramework predicts: $\\partial L / \\partial v < 0$',
     'upper right', None),
    (axes[1, 0], results_refresh, 'mu', 'Refresh rate $\\mu_j$',
     'Test 3: Refresh\nFramework predicts: $\\partial L / \\partial \\mu > 0$',
     'lower right', None),
    (axes[1, 1], results_rotation, 'eta', 'Rotation rate $\\eta_j$',
     'Test 4: Rotation\nFramework predicts: $\\partial L / \\partial \\eta > 0$',
     'lower right', None),
]

for panel_ax, sweep, x_key, x_label, title, legend_loc, y_limits in _panels:
    for rule, fmt, color, rule_label in _rule_styles:
        panel_ax.plot(sweep[x_key], sweep[rule], fmt,
                      color=color, linewidth=2.2, markersize=9, label=rule_label)
    panel_ax.set_xlabel(x_label)
    panel_ax.set_ylabel('Long-run stock $L(\\infty)$')
    panel_ax.set_title(title, fontweight='bold')
    panel_ax.legend(loc=legend_loc, fontsize=9, framealpha=0.95)
    panel_ax.grid(True, alpha=0.3)
    if y_limits is not None:
        panel_ax.set_ylim(*y_limits)

plt.suptitle('Audience-Formation Sensitivity Check: Four Comparative Statics',
             fontsize=14, fontweight='bold', y=1.00)
plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/audience_sensitivity_comparative_statics.png',
            dpi=150, bbox_inches='tight')
plt.close()
print("\nFigure 'audience_sensitivity_comparative_statics.png' saved")


# Figure 2: Verdict summary
# A table-like figure drawn on a hidden 14 x 11 data-coordinate canvas:
# - a title,
# - a header row,
# - one row per comparative static (slope and PASS/FAIL for each aggregation
#   rule, plus a consensus verdict),
# - a concluding text box at the bottom.
fig, ax = plt.subplots(figsize=(13, 7))
ax.set_xlim(0, 14)
ax.set_ylim(0, 11)
ax.axis('off')

ax.text(7, 10.4, 'Audience-Formation Sensitivity Verdict',
        fontsize=15, fontweight='bold', ha='center')
ax.text(7, 9.7, 'Do the framework\'s qualitative predictions survive under different aggregation rules?',
        fontsize=11, ha='center', style='italic', color='#666')

# Header row: dark rounded bar with white column titles.
header_y = 8.5
ax.add_patch(FancyBboxPatch((0.3, header_y - 0.4), 13.4, 0.8,
                            boxstyle="round,pad=0.05",
                            facecolor='#374151', edgecolor='black', linewidth=1.5))
headers = ['Comparative Static', 'Reduced Form', 'Bayesian-with-noise',
           'Threshold-cascade', 'Verdict']
# Column centres (x positions); the per-row cells below reuse these values.
header_x = [1.5, 4.5, 7.0, 9.5, 12.0]
for x, h in zip(header_x, headers):
    ax.text(x, header_y, h, ha='center', va='center',
            fontsize=10, fontweight='bold', color='white')

# Row labels, keyed by the test names used in `slopes`; dict order sets row order.
test_labels = {
    'diversity': 'Diversity ($J$)',
    'opacity': 'Detection ($v$)',
    'refresh': 'Refresh ($\\mu$)',
    'rotation': 'Rotation ($\\eta$)'
}

row_y = header_y - 1.0
for test_name, label in test_labels.items():
    data = slopes[test_name]
    # Each rule passes when its slope has the predicted sign (via check_sign).
    rf_ok = check_sign(data['reduced'], data['predicted_sign'])
    bay_ok = check_sign(data['bayesian'], data['predicted_sign'])
    cas_ok = check_sign(data['cascade'], data['predicted_sign'])
    consensus = rf_ok and bay_ok and cas_ok
    
    # Green row when all three rules pass, orange otherwise.
    box_color = '#d4f4dd' if consensus else '#ffe5cc'
    border_color = '#15803d' if consensus else '#cc7700'
    
    ax.add_patch(FancyBboxPatch((0.3, row_y - 0.5), 13.4, 1.0,
                                boxstyle="round,pad=0.04",
                                facecolor=box_color, edgecolor=border_color,
                                linewidth=1.3))
    
    ax.text(1.5, row_y, label, ha='center', va='center', fontsize=10, fontweight='bold')
    
    rf_label = 'PASS' if rf_ok else 'FAIL'
    bay_label = 'PASS' if bay_ok else 'FAIL'
    cas_label = 'PASS' if cas_ok else 'FAIL'
    
    # Each rule's cell: slope on the upper line, PASS/FAIL on the lower line.
    for x, slope, ok_label, ok in [(4.5, data['reduced'], rf_label, rf_ok),
                                     (7.0, data['bayesian'], bay_label, bay_ok),
                                     (9.5, data['cascade'], cas_label, cas_ok)]:
        ax.text(x, row_y + 0.18, f'slope = {slope:.3f}',
                ha='center', va='center', fontsize=9)
        ax.text(x, row_y - 0.18, ok_label,
                ha='center', va='center', fontsize=9, fontweight='bold',
                color='#15803d' if ok else '#cc1111')
    
    verdict = 'ROBUST' if consensus else 'AGGREGATION-\nDEPENDENT'
    verdict_color = '#15803d' if consensus else '#cc7700'
    ax.text(12.0, row_y, verdict, ha='center', va='center',
            fontsize=10, fontweight='bold', color=verdict_color)
    
    # Step down to the next row.
    row_y -= 1.2

# Overall verdict: every comparative static passes under every aggregation rule.
all_robust = all(
    check_sign(slopes[t]['reduced'], slopes[t]['predicted_sign']) and
    check_sign(slopes[t]['bayesian'], slopes[t]['predicted_sign']) and
    check_sign(slopes[t]['cascade'], slopes[t]['predicted_sign'])
    for t in slopes
)

# Concluding message and box colours chosen by the overall verdict.
if all_robust:
    conclusion = ('All four comparative statics survive both alternative aggregation rules.\n'
                  'The audience-formation gap appears to hide micro-detail rather than first-order effects.\n'
                  'The framework\'s scope choice (treating audience properties as observable primitives) is supported.')
    bg_color = '#e6f4ea'
    border = '#15803d'
else:
    conclusion = ('Some comparative statics depend on the aggregation rule.\n'
                  'The framework\'s qualitative predictions are not fully invariant to audience-formation assumptions.\n'
                  'The scope statement should specify which aggregation properties are required.')
    bg_color = '#fff4e6'
    border = '#cc7700'

ax.text(7, 0.7, conclusion, ha='center', va='center',
        fontsize=10.5, style='italic', fontweight='bold',
        bbox=dict(boxstyle='round,pad=0.5', facecolor=bg_color,
                 edgecolor=border, linewidth=1.5, alpha=0.95))

plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/audience_sensitivity_verdict_summary.png',
            dpi=150, bbox_inches='tight')
plt.close()
print("Figure 'audience_sensitivity_verdict_summary.png' saved")

print("\nDone.")
