Try an interactive version of this dialog: Sign up at solve.it.com, click Upload, and pass this URL.

Causal Inference in Python - Part 3: Effect Heterogeneity and Personalization

Data Generation

import numpy as np
import pandas as pd

# Simulated experiment: tutoring is an independent fair-coin assignment, and
# its effect on final_grade depends on struggling status, prior grade and
# grade level. Every draw below comes from the global NumPy RNG seeded here,
# so the ORDER of the draws is part of the data definition.
np.random.seed(42)
n = 2000

# Covariates and treatment (keep this draw order fixed).
struggling = np.random.binomial(1, 0.4, n)
prior_grade = np.clip(np.random.normal(70 - 12 * struggling, 10, n), 0, 100)
grade_level = np.random.choice([7, 8, 9], n)
tutoring = np.random.binomial(1, 0.5, n)

# Untreated outcome: linear in prior grade, 14 points lower for struggling
# students, plus a fixed bump per grade level.
grade_level_effect = {7: 0, 8: 2, 9: 4}
level_bump = np.array([grade_level_effect[level] for level in grade_level])
baseline = 0.8 * prior_grade + 20 - 14 * struggling + level_bump

# Effect for struggling students: 6 points everywhere, rising linearly to 16
# at a prior grade of 65 and back to 6 at 12 points either side.
tent = np.maximum(0, 1 - np.abs(prior_grade - 65) / 12)
cate_s = 6 + 10 * tent

# Effect for everyone else: 2 points, plus up to 4 more for 9th graders whose
# prior grade lies within 15 points of 72 (peaking at exactly 72).
quad = np.maximum(0, 1 - ((prior_grade - 72) / 15)**2)
cate_ns = 2 + 4 * (grade_level == 9) * quad

true_cate = np.where(struggling, cate_s, cate_ns)

# Observed outcome: baseline + effect (if tutored) + N(0, 5) noise, kept in
# the 0-100 grade range.
noise = np.random.normal(0, 5, n)
final_grade = np.clip(baseline + true_cate * tutoring + noise, 0, 100)

df = pd.DataFrame({
    'struggling': struggling,
    'prior_grade': prior_grade.round(1),
    'grade_level': grade_level,
    'tutoring': tutoring,
    'true_cate': true_cate.round(3),
    'final_grade': final_grade.round(1),
})
df.head()
struggling prior_grade grade_level tutoring true_cate final_grade
0 0 61.2 7 0 2.000 69.4
1 1 49.7 7 0 6.000 38.0
2 1 55.7 8 0 8.279 53.3
3 0 73.7 7 1 2.000 85.3
4 0 79.1 8 1 2.000 81.1
import matplotlib.pyplot as plt
from itertools import product

# Scatter the true CATE against prior grade, with one colour/marker per
# (struggling, grade level) group, then print outcome summaries.
fig, ax = plt.subplots(figsize=(9, 4))
colors = {
    (0, 7): '#aec6e8', (0, 8): '#4a90d9', (0, 9): '#1a4f8a',
    (1, 7): '#f4a582', (1, 8): '#d6604d', (1, 9): '#8b1a1a',
}
markers = {7: 'o', 8: 's', 9: '^'}
# Iterating the colour map visits the groups in the same order as
# product([0, 1], [7, 8, 9]), so the legend order is unchanged.
for (s, g), group_color in colors.items():
    status = 'Struggling' if s else 'Not struggling'
    group = df[(df.struggling == s) & (df.grade_level == g)]
    ax.scatter(group['prior_grade'], group['true_cate'], alpha=0.4, s=12,
               color=group_color, marker=markers[g], label=f"{status}, G{g}")
ax.set_xlabel('Prior Grade')
ax.set_ylabel('True CATE')
ax.set_title('True CATE vs Prior Grade')
ax.legend(fontsize=7, ncol=2)
plt.tight_layout()
plt.show()

print(df.final_grade.describe())
print(df.groupby('struggling')['final_grade'].mean())
count    2000.000000
mean       70.962000
std        14.422361
min        27.800000
25%        61.000000
50%        72.200000
75%        81.600000
max       100.000000
Name: final_grade, dtype: float64
struggling
0    78.981826
1    59.349449
Name: final_grade, dtype: float64
from sklearn.ensemble import RandomForestRegressor

# Column roles shared by every model below: covariates, outcome, treatment.
X, y, T = ['struggling', 'prior_grade', 'grade_level'], 'final_grade', 'tutoring'

# Outcome-prediction score: one forest predicting final_grade from the
# covariates AND the treatment column. It predicts a grade, not an effect;
# it is one of the scores compared in the evaluation sections.
with_treatment = X + [T]
rf = RandomForestRegressor(random_state=42)
rf.fit(df[with_treatment], df[y])
df['ml_pred'] = rf.predict(df[with_treatment])

# T-learner: one forest per treatment arm (control first, then treated).
# Each student's score is the treated-arm prediction minus the control-arm
# prediction.
# NOTE: predictions are made on the same DataFrame the models were fitted on,
# so they are in-sample for the rows each model saw.
m0, m1 = [
    RandomForestRegressor(random_state=42).fit(arm[X], arm[y])
    for arm in (df.query(f"{T}=={arm_value}") for arm_value in (0, 1))
]
df['t_learner_pred'] = m1.predict(df[X]) - m0.predict(df[X])

# Uninformative reference score: uniform noise from a dedicated seeded generator.
df['random_pred'] = np.random.default_rng(42).uniform(-10, 10, size=len(df))
df.head()
struggling prior_grade grade_level tutoring true_cate final_grade ml_pred t_learner_pred random_pred
0 0 61.2 7 0 2.000 69.4 69.626500 9.685067 5.479121
1 1 49.7 7 0 6.000 38.0 43.674583 13.761405 -1.222431
2 1 55.7 8 0 8.279 53.3 53.966833 7.750133 7.171958
3 0 73.7 7 1 2.000 85.3 85.652833 8.923217 3.947361
4 0 79.1 8 1 2.000 81.1 86.163090 -3.689333 -8.116453

Tool 1: Effect-by-Quantile

from fklearn.causal.validation.curves import effect_by_segment

# One panel per scoring column. Each bar is the treatment effect estimated
# within one band of that column's predictions (bands from effect_by_segment).
score_cols = ['random_pred', 'ml_pred', 't_learner_pred']
fig, axs = plt.subplots(1, len(score_cols), figsize=(12, 4), sharey=True)
for panel, col in zip(axs, score_cols):
    band_effects = effect_by_segment(df, 'tutoring', 'final_grade', col)
    positions = np.arange(len(band_effects))
    panel.bar(positions, band_effects.values)
    panel.set_title(col)
    panel.set_xlabel('Quantile')
axs[0].set_ylabel('Estimated Effect')
plt.tight_layout()
/app/data/.local/lib/python3.12/site-packages/fklearn/causal/validation/curves.py:53: FutureWarning: The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.
  .groupby(f"{prediction}_band")
/app/data/.local/lib/python3.12/site-packages/fklearn/causal/validation/curves.py:54: FutureWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.
  .apply(effect_fn_partial)
/app/data/.local/lib/python3.12/site-packages/fklearn/causal/validation/curves.py:53: FutureWarning: The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.
  .groupby(f"{prediction}_band")
/app/data/.local/lib/python3.12/site-packages/fklearn/causal/validation/curves.py:54: FutureWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.
  .apply(effect_fn_partial)
/app/data/.local/lib/python3.12/site-packages/fklearn/causal/validation/curves.py:53: FutureWarning: The default of observed=False is deprecated and will be changed to True in a future version of pandas. Pass observed=False to retain current behavior or observed=True to adopt the future default and silence this warning.
  .groupby(f"{prediction}_band")
/app/data/.local/lib/python3.12/site-packages/fklearn/causal/validation/curves.py:54: FutureWarning: DataFrameGroupBy.apply operated on the grouping columns. This behavior is deprecated, and in a future version of pandas the grouping columns will be excluded from the operation. Either pass `include_groups=False` to exclude the groupings or explicitly select the grouping columns after groupby to silence this warning.
  .apply(effect_fn_partial)

Tool 2: Cumulative Gain Curves

from fklearn.causal.validation.curves import cumulative_gain_curve, relative_cumulative_gain_curve

# Cumulative gain (left) and relative cumulative gain (right) per scoring column.
#
# The x positions must match the row counts at which each curve is evaluated.
# min_rows/steps are passed explicitly and that schedule is rebuilt here: the
# first `min_rows` rows, then every `n_total // steps` rows, then all rows.
# Spreading the points evenly over 0-100% instead would put the first point at
# 0% rather than at min_rows / n_total.
# NOTE(review): this schedule mirrors fklearn's curves.py; the length check in
# the loop fails loudly if the installed fklearn computes it differently.
min_rows, steps = 30, 100
n_total = len(df)
n_rows = list(range(min_rows, n_total, n_total // steps)) + [n_total]
x_pct = 100 * np.array(n_rows) / n_total

fig, axs = plt.subplots(1, 2, figsize=(12, 4))
gain_curves = {}
for m in ['random_pred', 'ml_pred', 't_learner_pred']:
    cg = cumulative_gain_curve(df, 'tutoring', 'final_grade', m,
                               min_rows=min_rows, steps=steps)
    rcg = relative_cumulative_gain_curve(df, 'tutoring', 'final_grade', m,
                                         min_rows=min_rows, steps=steps)
    if len(cg) != len(x_pct) or len(rcg) != len(x_pct):
        raise ValueError(
            f"gain curves for {m} have {len(cg)}/{len(rcg)} points, expected "
            f"{len(x_pct)}; the evaluation schedule differs from the one assumed here"
        )
    gain_curves[m] = cg
    axs[0].plot(x_pct, cg, label=m)
    axs[1].plot(x_pct, rcg, label=m)

# Baseline: straight line from the origin to the gain at 100% of rows, reusing
# the t_learner_pred curve computed in the loop rather than recomputing it.
axs[0].plot([0, 100], [0, gain_curves['t_learner_pred'][-1]], 'k--', label='baseline')
axs[1].axhline(0, color='k', linestyle='--', label='baseline')
axs[0].set(xlabel='Top %', ylabel='Cumulative Gain')
axs[1].set(xlabel='Top %', ylabel='Relative Cumulative Gain')
for ax in axs:
    ax.legend()
plt.tight_layout()
def auc(df, t, y, m):
    """Summarise the relative cumulative gain curve as a single number.

    Args:
        df: data containing the treatment, outcome and score columns.
        t: name of the treatment column.
        y: name of the outcome column.
        m: name of the score column used to rank rows.

    Returns:
        The mean of ``relative_cumulative_gain_curve(df, t, y, m)``, used in
        this notebook as an AUC-like score for comparing scoring columns.
    """
    curve = relative_cumulative_gain_curve(df, t, y, m)
    return curve.mean()
# Print the AUC-like summary for every scoring column.
for m in ['random_pred', 'ml_pred', 't_learner_pred']:
    score = auc(df, 'tutoring', 'final_grade', m)
    print(f"{m} AUC: {score:.2f}")
random_pred AUC: 0.45
ml_pred AUC: -1.74
t_learner_pred AUC: 1.89

Tool 3: Target Transformation

import statsmodels.formula.api as smf
from sklearn.metrics import mean_squared_error

# Step 1: residualise the outcome and the treatment on the same covariates
# with OLS (outcome model first, then treatment model).
X_cols = "struggling + prior_grade + C(grade_level)"
y_res, t_res = (
    smf.ols(f"{lhs} ~ {X_cols}", data=df).fit().resid
    for lhs in ("final_grade", "tutoring")
)

# Step 2: transformed target Y* = y_res / t_res with weights t_res**2.
# Because t_res**2 * (y_res / t_res - tau)**2 == (y_res - t_res * tau)**2,
# the weighted MSE below equals sum((y_res - t_res * tau)**2) / sum(t_res**2).
# The denominator is the same for every column, so rankings are comparable.
# NOTE: y_star is undefined (inf/nan) for any row where t_res == 0.
y_star = y_res / t_res
weights = t_res**2

# Step 3: weighted MSE of each scoring column against Y* (lower is better).
for m in ['random_pred', 'ml_pred', 't_learner_pred']:
    wmse = mean_squared_error(y_star, df[m], sample_weight=weights)
    print(f"{m} Weighted MSE: {wmse:.2f}")
random_pred Weighted MSE: 195.06
ml_pred Weighted MSE: 4603.56
t_learner_pred Weighted MSE: 75.63