01: Causal Fundamentals
“Correlation does not imply causation.” - Every Statistics 101 Student
Welcome to the world of causal inference! This notebook introduces the foundational concepts that separate correlation from causation. We'll explore why most data science focuses on prediction while causal inference tackles the deeper questions of “why” and “what if”.
Learning Objectives
By the end of this notebook, you'll understand:
The difference between correlation and causation
Counterfactual reasoning and potential outcomes
The fundamental problem of causal inference
Simpson's paradox and its implications
Association vs causation with real examples
The Correlation vs Causation Problem
Classic Example: Ice cream sales and drowning deaths both increase in summer. Does eating ice cream cause drowning?
This is the classic “correlation does not imply causation” example. Both variables are correlated, but the relationship is spurious - they're both caused by a third variable (hot weather).
Types of Relationships
Causation: A → B (A causes B)
Reverse Causation: B → A (B causes A)
Common Cause: C → A and C → B
Spurious Correlation: No causal relationship
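The first three patterns above can be hard to tell apart from data alone. As a minimal sketch (variable names are illustrative), the following simulation shows that direct causation and a common cause both produce a clearly positive correlation, so the correlation coefficient by itself cannot distinguish the two structures:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000

# Direct causation: A -> B
a = rng.normal(size=n)
b = 2 * a + rng.normal(size=n)

# Common cause: C -> A2 and C -> B2 (no arrow between A2 and B2)
c = rng.normal(size=n)
a2 = c + rng.normal(size=n)
b2 = c + rng.normal(size=n)

print(f"corr under direct causation: {np.corrcoef(a, b)[0, 1]:.2f}")
print(f"corr under common cause:     {np.corrcoef(a2, b2)[0, 1]:.2f}")
```

Both correlations come out strongly positive even though only the first pair has a causal arrow between the correlated variables.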
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import warnings
warnings.filterwarnings('ignore')
# Set random seeds
np.random.seed(42)
# Set up plotting
plt.style.use('default')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = [12, 8]
plt.rcParams['font.size'] = 12
print("Causal Inference Fundamentals")
print("=============================")
def demonstrate_spurious_correlation():
"""Demonstrate spurious correlation with ice cream and drowning example"""
print("=== Spurious Correlation: Ice Cream & Drowning ===\n")
# Generate synthetic data
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
n_months = len(months)
# Temperature (confounder)
temperature = np.array([10, 12, 15, 20, 25, 30, 35, 33, 28, 22, 15, 10])
# Ice cream sales (caused by temperature)
ice_cream_sales = 50 + 2 * temperature + np.random.normal(0, 5, n_months)
# Drowning deaths (caused by temperature)
drowning_deaths = 5 + 0.3 * temperature + np.random.normal(0, 2, n_months)
# Create DataFrame
data = pd.DataFrame({
'Month': months,
'Temperature': temperature,
'Ice_Cream_Sales': ice_cream_sales,
'Drowning_Deaths': drowning_deaths
})
# Calculate correlations
corr_ice_drown = data['Ice_Cream_Sales'].corr(data['Drowning_Deaths'])
corr_temp_ice = data['Temperature'].corr(data['Ice_Cream_Sales'])
corr_temp_drown = data['Temperature'].corr(data['Drowning_Deaths'])
print(f"Correlation between Ice Cream Sales and Drowning Deaths: {corr_ice_drown:.3f}")
print(f"Correlation between Temperature and Ice Cream Sales: {corr_temp_ice:.3f}")
print(f"Correlation between Temperature and Drowning Deaths: {corr_temp_drown:.3f}")
# Visualize the relationships
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
# Time series plot
ax1.plot(data['Month'], data['Ice_Cream_Sales'], 'b-o', label='Ice Cream Sales')
ax1.plot(data['Month'], data['Drowning_Deaths'], 'r-s', label='Drowning Deaths')
ax1.set_title('Monthly Trends: Ice Cream vs Drowning')
ax1.set_xlabel('Month')
ax1.set_ylabel('Count')
ax1.legend()
ax1.tick_params(axis='x', rotation=45)
# Scatter plot: Ice cream vs Drowning (spurious correlation)
ax2.scatter(data['Ice_Cream_Sales'], data['Drowning_Deaths'], alpha=0.7)
ax2.set_title(f'Ice Cream Sales vs Drowning Deaths\nCorrelation: {corr_ice_drown:.3f}')
ax2.set_xlabel('Ice Cream Sales')
ax2.set_ylabel('Drowning Deaths')
# Add trend line
z = np.polyfit(data['Ice_Cream_Sales'], data['Drowning_Deaths'], 1)
p = np.poly1d(z)
ax2.plot(data['Ice_Cream_Sales'], p(data['Ice_Cream_Sales']), "r--", alpha=0.8)
# Scatter plot: Temperature vs Ice Cream (causal)
ax3.scatter(data['Temperature'], data['Ice_Cream_Sales'], alpha=0.7)
ax3.set_title(f'Temperature vs Ice Cream Sales\nCorrelation: {corr_temp_ice:.3f}')
ax3.set_xlabel('Temperature (°C)')
ax3.set_ylabel('Ice Cream Sales')
# Scatter plot: Temperature vs Drowning (causal)
ax4.scatter(data['Temperature'], data['Drowning_Deaths'], alpha=0.7)
ax4.set_title(f'Temperature vs Drowning Deaths\nCorrelation: {corr_temp_drown:.3f}')
ax4.set_xlabel('Temperature (°C)')
ax4.set_ylabel('Drowning Deaths')
plt.tight_layout()
plt.show()
print("\n๐ Analysis:")
print("- Ice cream sales and drowning deaths are highly correlated")
print("- But neither causes the other directly")
print("- Both are caused by temperature (confounding variable)")
print("- This is a classic example of spurious correlation")
return data
# Demonstrate spurious correlation
spurious_data = demonstrate_spurious_correlation()
Counterfactual Reasoning
Counterfactual: “What would have happened if…?”
Causal inference is fundamentally about counterfactuals. For each individual, we want to know:
What was their outcome with treatment? (observed)
What would their outcome have been without treatment? (counterfactual)
The Fundamental Problem of Causal Inference
We can never observe both potential outcomes for the same individual simultaneously. This is the core challenge of causal inference.
Notation:
\(Y_i(1)\): Outcome if individual i receives treatment
\(Y_i(0)\): Outcome if individual i does not receive treatment
\(T_i\): Treatment indicator (1 if treated, 0 if not)
Observed outcome: \(Y_i = T_i \cdot Y_i(1) + (1-T_i) \cdot Y_i(0)\)
Individual Treatment Effect: \(ITE_i = Y_i(1) - Y_i(0)\)
Average Treatment Effect: \(ATE = E[Y_i(1) - Y_i(0)]\)
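The consistency identity above can be checked numerically. A minimal sketch (synthetic potential outcomes, illustrative parameters): under random assignment, the simple difference in observed group means recovers the ATE even though no individual's \(Y_i(1)\) and \(Y_i(0)\) are ever both observed.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 100_000

Y0 = rng.normal(50, 10, n)      # potential outcome without treatment
Y1 = Y0 + 5                     # constant individual treatment effect of 5
T = rng.binomial(1, 0.5, n)     # random assignment, as in an RCT

# Consistency: the observed outcome reveals exactly one potential outcome
Y_obs = T * Y1 + (1 - T) * Y0

# Under randomization, the difference in group means estimates the ATE
ate_hat = Y_obs[T == 1].mean() - Y_obs[T == 0].mean()
print(f"true ATE = 5.00, estimate = {ate_hat:.2f}")
```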
def demonstrate_counterfactuals():
"""Demonstrate counterfactual reasoning with potential outcomes"""
print("=== Counterfactual Reasoning & Potential Outcomes ===\n")
# Simulate a simple treatment effect
np.random.seed(42)
n_people = 1000
# Generate potential outcomes
# Y(0): outcome without treatment (baseline)
# Y(1): outcome with treatment
baseline_outcome = np.random.normal(50, 10, n_people)
treatment_effect = np.random.normal(5, 2, n_people) # Average effect of 5
# Potential outcomes
Y0 = baseline_outcome # Outcome without treatment
Y1 = baseline_outcome + treatment_effect # Outcome with treatment
# Random treatment assignment (RCT)
treatment_assignment = np.random.binomial(1, 0.5, n_people)
# Observed outcomes
observed_outcomes = treatment_assignment * Y1 + (1 - treatment_assignment) * Y0
# Create DataFrame
causal_data = pd.DataFrame({
'Person_ID': range(n_people),
'Treated': treatment_assignment,
'Y0': Y0, # Counterfactual: what would have happened without treatment
'Y1': Y1, # Counterfactual: what would have happened with treatment
'Y_observed': observed_outcomes,
'ITE': Y1 - Y0 # Individual treatment effect
})
# Calculate causal effects
treated_group = causal_data[causal_data['Treated'] == 1]
control_group = causal_data[causal_data['Treated'] == 0]
# Naive comparison (observed outcomes)
naive_effect = treated_group['Y_observed'].mean() - control_group['Y_observed'].mean()
# True causal effects
true_ate = causal_data['ITE'].mean()
true_att = treated_group['ITE'].mean() # Average treatment effect on treated
true_atc = control_group['ITE'].mean() # Average treatment effect on controls
print(f"Sample Size: {n_people} people")
print(f"Treatment Rate: {treatment_assignment.mean():.1%}")
print()
print("Causal Effect Estimates:")
print(f"Naive comparison (observed): {naive_effect:.2f}")
print(f"True ATE (all individuals): {true_ate:.2f}")
print(f"True ATT (treated only): {true_att:.2f}")
print(f"True ATC (controls only): {true_atc:.2f}")
print()
# Show individual examples
print("Individual Examples:")
examples = causal_data.sample(5, random_state=42)
for _, person in examples.iterrows():
treated = "Treated" if person['Treated'] else "Control"
observed = person['Y_observed']
counterfactual = person['Y1'] if not person['Treated'] else person['Y0']
ite = person['ITE']
print(f"Person {int(person['Person_ID']):3d} ({treated:7}): "
f"Observed: {observed:5.1f}, "
f"Counterfactual: {counterfactual:5.1f}, "
f"ITE: {ite:+5.1f}")
# Visualize the fundamental problem
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
# Plot potential outcomes
ax1.scatter(causal_data['Y0'], causal_data['Y1'], alpha=0.6, c=causal_data['Treated'],
cmap='coolwarm', s=50)
ax1.plot([causal_data['Y0'].min(), causal_data['Y0'].max()],
[causal_data['Y0'].min(), causal_data['Y0'].max()], 'k--', alpha=0.7)
ax1.set_xlabel('Outcome without Treatment Y(0)')
ax1.set_ylabel('Outcome with Treatment Y(1)')
ax1.set_title('Potential Outcomes Framework\n(Fundamental Problem of Causal Inference)')
ax1.grid(True, alpha=0.3)
# Add colorbar
cbar = plt.colorbar(ax1.collections[0], ax=ax1)
cbar.set_label('Treatment Status')
cbar.set_ticks([0, 1])
cbar.set_ticklabels(['Control', 'Treated'])
# Plot individual treatment effects
ax2.hist(causal_data['ITE'], bins=30, alpha=0.7, edgecolor='black')
ax2.axvline(causal_data['ITE'].mean(), color='red', linestyle='--', linewidth=2,
label=f'Mean ITE: {causal_data["ITE"].mean():.2f}')
ax2.set_xlabel('Individual Treatment Effect (ITE)')
ax2.set_ylabel('Frequency')
ax2.set_title('Distribution of Treatment Effects')
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("\n๐ Key Insights:")
print("1. We never observe both Y(0) and Y(1) for the same person")
print("2. The diagonal line shows where Y(1) = Y(0) (no treatment effect)")
print("3. Points above the line benefit from treatment")
print("4. Treatment effects vary across individuals")
print("5. RCTs help us estimate average effects despite this fundamental problem")
return causal_data
# Demonstrate counterfactuals
counterfactual_data = demonstrate_counterfactuals()
Simpson's Paradox
Simpson's Paradox: A trend appears in different groups of data but disappears or reverses when the groups are combined.
Classic Example: Berkeley Gender Bias Case (1973)
When looking at individual departments, women had higher admission rates
When looking at the university overall, men had higher admission rates
This happened because women applied to more competitive departments
Why This Matters for Causal Inference
Simpson's paradox shows how aggregation can hide or reverse causal relationships. It demonstrates the importance of:
Stratification: Looking at subgroups separately
Confounding: Controlling for important variables
Careful analysis: Not jumping to conclusions from aggregate data
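The reversal is nothing mysterious: each gender's overall rate is a weighted average of its department rates, with weights set by where that gender applies. A minimal sketch (admission counts are illustrative, not the historical Berkeley figures):

```python
import pandas as pd

# Illustrative admission counts (two departments, two genders)
df = pd.DataFrame({
    "dept":     ["A", "A", "B", "B"],
    "gender":   ["M", "F", "M", "F"],
    "applied":  [1000, 100, 100, 1000],
    "admitted": [800, 90, 10, 450],
})

# Within each department, women have the higher admission rate
by_dept = df.assign(rate=df["admitted"] / df["applied"])
print(by_dept)

# Aggregated over departments, the ranking reverses: most women
# applied to the competitive department B, dragging their pooled rate down
overall = df.groupby("gender")[["applied", "admitted"]].sum()
overall["rate"] = overall["admitted"] / overall["applied"]
print(overall)
```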
def demonstrate_simpsons_paradox():
"""Demonstrate Simpson's paradox with Berkeley admissions example"""
print("=== Simpson's Paradox: Berkeley Admissions ===\n")
# Create synthetic Berkeley admissions data
# Two departments with different competitiveness
# Department A: Less competitive (higher admission rates)
dept_a_men_applied = 1000
dept_a_men_admitted = 800 # 80% admission rate
dept_a_women_applied = 100
dept_a_women_admitted = 90 # 90% admission rate
# Department B: More competitive (lower admission rates)
dept_b_men_applied = 100
dept_b_men_admitted = 10 # 10% admission rate
dept_b_women_applied = 1000
dept_b_women_admitted = 450 # 45% admission rate
# Create DataFrame
admissions_data = pd.DataFrame({
'Department': ['A', 'A', 'B', 'B'],
'Gender': ['Men', 'Women', 'Men', 'Women'],
'Applied': [dept_a_men_applied, dept_a_women_applied,
dept_b_men_applied, dept_b_women_applied],
'Admitted': [dept_a_men_admitted, dept_a_women_admitted,
dept_b_men_admitted, dept_b_women_admitted]
})
admissions_data['Admission_Rate'] = (admissions_data['Admitted'] / admissions_data['Applied'] * 100).round(1)
# Calculate aggregate statistics
total_men_applied = admissions_data[admissions_data['Gender'] == 'Men']['Applied'].sum()
total_men_admitted = admissions_data[admissions_data['Gender'] == 'Men']['Admitted'].sum()
total_women_applied = admissions_data[admissions_data['Gender'] == 'Women']['Applied'].sum()
total_women_admitted = admissions_data[admissions_data['Gender'] == 'Women']['Admitted'].sum()
overall_men_rate = (total_men_admitted / total_men_applied * 100).round(1)
overall_women_rate = (total_women_admitted / total_women_applied * 100).round(1)
print("Berkeley Admissions Data:")
print(admissions_data.to_string(index=False))
print()
print("Admission Rates by Department and Gender:")
pivot_table = admissions_data.pivot(index='Department', columns='Gender', values='Admission_Rate')
print(pivot_table)
print()
print(f"Overall Admission Rates:")
print(f"Men: {overall_men_rate}%")
print(f"Women: {overall_women_rate}%")
print()
print("๐ญ Simpson's Paradox:")
print("- In Department A: Women have higher admission rate (90% > 80%)")
print("- In Department B: Women have higher admission rate (45% > 10%)")
print("- Overall: Men appear to have higher admission rate (44.4% > 41.7%)")
print()
# Visualize the paradox
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
# Department-level view
dept_data = admissions_data.copy()
dept_data['Label'] = dept_data['Department'] + '\n' + dept_data['Gender']
bars1 = ax1.bar(range(len(dept_data)), dept_data['Admission_Rate'],
color=['lightblue', 'pink', 'lightblue', 'pink'])
ax1.set_xticks(range(len(dept_data)))
ax1.set_xticklabels(dept_data['Label'])
ax1.set_ylabel('Admission Rate (%)')
ax1.set_title('Admission Rates by Department and Gender\n(Simpson\'s Paradox)')
ax1.set_ylim(0, 100)
# Add value labels
for bar, rate in zip(bars1, dept_data['Admission_Rate']):
height = bar.get_height()
ax1.text(bar.get_x() + bar.get_width()/2., height + 1,
f'{rate}%', ha='center', va='bottom')
# Overall view
overall_data = pd.DataFrame({
'Gender': ['Men', 'Women'],
'Admission_Rate': [overall_men_rate, overall_women_rate]
})
bars2 = ax2.bar(range(len(overall_data)), overall_data['Admission_Rate'],
color=['lightblue', 'pink'])
ax2.set_xticks(range(len(overall_data)))
ax2.set_xticklabels(overall_data['Gender'])
ax2.set_ylabel('Admission Rate (%)')
ax2.set_title('Overall Admission Rates\n(Aggregated Data)')
ax2.set_ylim(0, 100)  # leave headroom so the ~74% bar and its label are not clipped
# Add value labels
for bar, rate in zip(bars2, overall_data['Admission_Rate']):
height = bar.get_height()
ax2.text(bar.get_x() + bar.get_width()/2., height + 0.5,
f'{rate}%', ha='center', va='bottom')
plt.tight_layout()
plt.show()
print("\n๐ Why This Happens:")
print("1. Women applied disproportionately to the more competitive Department B")
print("2. Department B has lower admission rates overall")
print("3. When aggregated, this creates the appearance of gender bias")
print("4. The 'bias' disappears when we look at departments separately")
print()
print("๐ Causal Inference Lessons:")
print("- Always check for confounding variables (like department competitiveness)")
print("- Stratification can reveal hidden patterns")
print("- Aggregate statistics can be misleading")
print("- Context matters for causal interpretation")
return admissions_data
# Demonstrate Simpson's paradox
simpsons_data = demonstrate_simpsons_paradox()
Association vs Causation: Real-World Examples
Understanding the difference between association and causation is one of the most important skills in data science. An observed correlation between variables \(X\) and \(Y\) can arise from at least four distinct mechanisms: direct causation (\(X \to Y\)), reverse causation (\(Y \to X\)), a common cause or confounder (\(Z \to X\) and \(Z \to Y\)), or pure coincidence. The code below walks through five well-known examples, from coffee and longevity to vaccines and autism, and then uses synthetic data to show how controlling for a confounder (health consciousness) can weaken or eliminate what initially appears to be a strong relationship. This technique of stratification is the simplest form of confounding adjustment and motivates the more formal methods (regression, matching, instrumental variables) covered in later notebooks.
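The stratification idea can be distilled into a few lines. A minimal sketch (illustrative variable names and coefficients): compare the raw coffee-lifespan correlation with the average correlation inside narrow bands of the confounder.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
n = 5_000

# Confounder drives both variables; there is no direct coffee -> lifespan arrow
health = rng.normal(size=n)
coffee = 2 + health + rng.normal(0, 0.5, n)
lifespan = 75 + 2 * health + rng.normal(0, 5, n)

df = pd.DataFrame({"health": health, "coffee": coffee, "lifespan": lifespan})

raw = df["coffee"].corr(df["lifespan"])

# Stratify on the confounder: within narrow health bands the
# coffee-lifespan association should mostly disappear
df["stratum"] = pd.qcut(df["health"], 10, labels=False)
within = (
    df.groupby("stratum")[["coffee", "lifespan"]]
      .apply(lambda g: g["coffee"].corr(g["lifespan"]))
      .mean()
)

print(f"raw correlation:           {raw:.3f}")
print(f"mean within-stratum corr.: {within:.3f}")
```

The raw correlation is clearly positive, while the within-stratum average shrinks toward zero, which is exactly the signature of confounding rather than causation.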
def explore_causal_relationships():
"""Explore different types of causal relationships"""
print("=== Association vs Causation: Real-World Examples ===\n")
examples = [
{
'name': 'Coffee Consumption & Longevity',
'correlation': 'Coffee drinkers live longer',
'true_relationship': 'Common cause - healthier lifestyle',
'type': 'Confounding',
'evidence': 'After controlling for diet, exercise, etc., effect diminishes'
},
{
'name': 'Education & Income',
'correlation': 'Higher education leads to higher income',
'true_relationship': 'Causal (with some reverse causation)',
'type': 'Causation',
'evidence': 'RCTs show education interventions increase earnings'
},
{
'name': 'Ice Cream & Crime',
'correlation': 'Ice cream sales predict crime rates',
'true_relationship': 'Common cause - warm weather',
'type': 'Confounding',
'evidence': 'Crime rates follow temperature, not ice cream consumption'
},
{
'name': 'Poverty & Mental Health',
'correlation': 'Poor people have more mental health issues',
'true_relationship': 'Bidirectional causation',
'type': 'Reverse + Direct Causation',
'evidence': 'Longitudinal studies show both directions of effect'
},
{
'name': 'Vaccines & Autism',
'correlation': 'Vaccinated children have higher autism rates',
'true_relationship': 'No causation - age-related diagnosis',
'type': 'Spurious',
'evidence': 'Multiple large studies show no causal link'
}
]
print("Examples of Association vs Causation:")
print("=" * 60)
for i, example in enumerate(examples, 1):
print(f"\n{i}. {example['name']}")
print(f" Correlation: {example['correlation']}")
print(f" True Relationship: {example['true_relationship']}")
print(f" Type: {example['type']}")
print(f" Evidence: {example['evidence']}")
# Demonstrate with synthetic data
print("\n\n๐ Synthetic Data Demonstration:")
# Example: Coffee consumption and longevity
n_people = 1000
# Confounder: Health consciousness
health_consciousness = np.random.normal(0, 1, n_people)
# Coffee consumption (affected by health consciousness)
coffee_cups = 2 + health_consciousness + np.random.normal(0, 0.5, n_people)
coffee_cups = np.maximum(0, coffee_cups) # No negative coffee
# Longevity (affected by health consciousness)
longevity = 75 + 2 * health_consciousness + np.random.normal(0, 5, n_people)
# Create DataFrame
health_data = pd.DataFrame({
'Coffee_Cups_Per_Day': coffee_cups,
'Longevity_Years': longevity,
'Health_Consciousness': health_consciousness
})
# Calculate correlations
corr_coffee_longevity = health_data['Coffee_Cups_Per_Day'].corr(health_data['Longevity_Years'])
corr_health_coffee = health_data['Health_Consciousness'].corr(health_data['Coffee_Cups_Per_Day'])
corr_health_longevity = health_data['Health_Consciousness'].corr(health_data['Longevity_Years'])
print(f"Correlation (Coffee โ Longevity): {corr_coffee_longevity:.3f}")
print(f"Correlation (Health โ Coffee): {corr_health_coffee:.3f}")
print(f"Correlation (Health โ Longevity): {corr_health_longevity:.3f}")
# Visualize
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
# Raw correlation (spurious)
ax1.scatter(health_data['Coffee_Cups_Per_Day'], health_data['Longevity_Years'], alpha=0.6)
ax1.set_xlabel('Coffee Cups Per Day')
ax1.set_ylabel('Longevity (Years)')
ax1.set_title(f'Coffee vs Longevity\nCorrelation: {corr_coffee_longevity:.3f}')
ax1.grid(True, alpha=0.3)
# Add trend line
z = np.polyfit(health_data['Coffee_Cups_Per_Day'], health_data['Longevity_Years'], 1)
p = np.poly1d(z)
x_trend = np.linspace(health_data['Coffee_Cups_Per_Day'].min(), health_data['Coffee_Cups_Per_Day'].max(), 100)
ax1.plot(x_trend, p(x_trend), "r--", alpha=0.8)
# Controlling for confounder
# Split by health consciousness tertiles
health_data['Health_Tertile'] = pd.qcut(health_data['Health_Consciousness'], 3, labels=['Low', 'Medium', 'High'])
colors = ['red', 'orange', 'green']
for i, (name, group) in enumerate(health_data.groupby('Health_Tertile')):
ax2.scatter(group['Coffee_Cups_Per_Day'], group['Longevity_Years'],
alpha=0.6, color=colors[i], label=f'{name} Health Consciousness')
ax2.set_xlabel('Coffee Cups Per Day')
ax2.set_ylabel('Longevity (Years)')
ax2.set_title('Coffee vs Longevity\n(Controlling for Health Consciousness)')
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
print("\n๐ Key Insight:")
print("- Raw correlation suggests coffee causes longevity")
print("- After controlling for health consciousness, the relationship weakens")
print("- This demonstrates the importance of accounting for confounders")
return health_data
# Explore causal relationships
causal_examples = explore_causal_relationships()
Key Takeaways
1. Correlation ≠ Causation
Just because two things move together doesn't mean one causes the other
Always look for confounding variables and alternative explanations
2. Counterfactual Reasoning
Causal inference is about “what would have happened if…”
We can never observe both potential outcomes for the same individual
This creates the fundamental problem of causal inference
3. Simpson's Paradox
Aggregate statistics can hide or reverse true relationships
Always check subgroups and control for important variables
4. Types of Relationships
Direct Causation: A → B
Reverse Causation: B → A
Common Cause: C → A and C → B
Spurious: No causal relationship
5. Causal Questions
Prediction: “What will happen?” (correlation is often sufficient)
Causation: “Why does this happen?” “What if we intervene?” (needs causal methods)
Critical Thinking Questions
Can you think of a correlation in your field that might be spurious?
How would you design an experiment to test if education causes higher income?
What's an example where reverse causation might be occurring?
How can Simpson's paradox affect business decisions?
Next Steps
Now that you understand the fundamentals, you're ready to dive deeper into:
Causal Graphs & DAGs: Visualizing causal relationships
Experimental Design: How to conduct valid causal studies
Observational Methods: Estimating causal effects from non-experimental data
Remember: Causal inference is about understanding mechanisms, not just patterns. It's the difference between describing the world and understanding how to change it!
“The correlation between A and B is not necessarily the result of A causing B, or B causing A, or both being caused by C. It could be that the correlation is a coincidence.” - Judea Pearl