Imbalanced Datasets: Handling Class Imbalance Properly
When 95% of samples are class 0, accuracy is meaningless. Learn SMOTE, class weights, threshold tuning, and the right evaluation metrics.
1. The Accuracy Paradox
Imagine a fraud detection system where 99% of transactions are legit. A model that always predicts "not fraud" achieves 99% accuracy — and catches zero frauds.
This is the accuracy paradox. Let's demonstrate it.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
accuracy_score, classification_report, confusion_matrix,
roc_auc_score, average_precision_score
)
from sklearn.dummy import DummyClassifier
# Build a severely imbalanced binary dataset: ~99% class 0, ~1% class 1.
X, y = make_classification(
    n_samples=10000,
    n_features=20,
    n_informative=5,
    weights=[0.99, 0.01],  # 99:1 imbalance
    random_state=42,
)

counts = np.bincount(y)
print(f"Class distribution: {counts}")
print(f"Imbalance ratio: {counts[0] / counts[1]:.0f}:1")

# Stratify so the rare class appears in the same proportion in train and test.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
# Baseline 1: the "dumb" model that always predicts the majority class.
dummy = DummyClassifier(strategy='most_frequent')
dummy.fit(X_train, y_train)
dummy_preds = dummy.predict(X_test)

# Baseline 2: logistic regression with no imbalance handling at all.
lr = LogisticRegression(random_state=42, max_iter=1000)
lr.fit(X_train, y_train)
lr_preds = lr.predict(X_test)

print("\n--- Always-Majority (Dummy) Classifier ---")
# Accuracy is dominated by the 99% majority class, so it looks excellent.
print(f"Accuracy: {accuracy_score(y_test, dummy_preds):.4f} — looks great!")
# The most-frequent strategy never predicts class 1, so fraud recall is exactly 0.
print(f"Recall (fraud): {0:.4f} — catches ZERO fraud cases")
print("\n--- Logistic Regression (no balancing) ---")
print(classification_report(y_test, lr_preds, target_names=['Legit', 'Fraud']))
2. The Right Metrics for Imbalanced Data
| Metric | Formula | Use When |
|---|---|---|
| Precision | TP / (TP + FP) | False alarms are costly (e.g., spam filter) |
| Recall | TP / (TP + FN) | Missing positives is costly (e.g., cancer detection) |
| F1 | 2 × P × R / (P + R) | Balance precision and recall |
| ROC-AUC | Area under ROC curve | Overall ranking ability |
| PR-AUC | Area under PR curve | Better than ROC-AUC for severe imbalance |
from sklearn.metrics import precision_recall_curve, roc_curve

# Fraud-class probability scores from the unbalanced baseline model.
lr_proba = lr.predict_proba(X_test)[:, 1]

precision, recall, pr_thresholds = precision_recall_curve(y_test, lr_proba)
fpr, tpr, roc_thresholds = roc_curve(y_test, lr_proba)

ap = average_precision_score(y_test, lr_proba)
auc = roc_auc_score(y_test, lr_proba)
# A random classifier's PR baseline is the positive-class prevalence.
prevalence = np.bincount(y_test)[1] / len(y_test)

fig, (ax_pr, ax_roc) = plt.subplots(1, 2, figsize=(13, 5))

# Left panel: precision-recall curve vs. the prevalence baseline.
ax_pr.plot(recall, precision, color='steelblue', lw=2, label=f'PR curve (AP={ap:.3f})')
ax_pr.axhline(prevalence, color='red', linestyle='--', label=f'Baseline (random) = {prevalence:.3f}')
ax_pr.set_xlabel('Recall')
ax_pr.set_ylabel('Precision')
ax_pr.set_title('Precision-Recall Curve')
ax_pr.legend()
ax_pr.grid(True, alpha=0.3)

# Right panel: ROC curve vs. the chance diagonal.
ax_roc.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC={auc:.3f})')
ax_roc.plot([0, 1], [0, 1], 'k--', label='Random classifier')
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_title('ROC Curve')
ax_roc.legend()
ax_roc.grid(True, alpha=0.3)

plt.suptitle('Evaluation Metrics for Imbalanced Classification', fontsize=13)
plt.tight_layout()
plt.show()

print(f"ROC-AUC: {auc:.4f}")
print(f"PR-AUC (Average Precision): {ap:.4f}")
print("\nFor severe imbalance, PR-AUC is more informative than ROC-AUC.")
print("ROC-AUC can look good even when precision is very low.")
3. class_weight='balanced': The Quick Fix
from sklearn.utils.class_weight import compute_class_weight

# class_weight='balanced' sets each class weight to n_samples / (n_classes * n_c),
# i.e. inversely proportional to class frequency, so the rare class counts more
# in the loss.
classes = np.unique(y_train)
weights = compute_class_weight('balanced', classes=classes, y=y_train)
print(f"Computed class weights: class 0 → {weights[0]:.4f}, class 1 → {weights[1]:.4f}")
print(f"Ratio: {weights[1]/weights[0]:.0f}:1 (mirrors the 99:1 imbalance)")

# Same model as the baseline, but with the reweighted loss.
lr_balanced = LogisticRegression(class_weight='balanced', random_state=42, max_iter=1000)
lr_balanced.fit(X_train, y_train)
balanced_preds = lr_balanced.predict(X_test)
balanced_proba = lr_balanced.predict_proba(X_test)[:, 1]

print("\n--- Logistic Regression (class_weight='balanced') ---")
print(classification_report(y_test, balanced_preds, target_names=['Legit', 'Fraud']))
print(f"ROC-AUC: {roc_auc_score(y_test, balanced_proba):.4f}")
print(f"PR-AUC: {average_precision_score(y_test, balanced_proba):.4f}")
4. SMOTE: Synthetic Minority Oversampling
try:
    # NOTE: the original also imported imblearn.pipeline.Pipeline here but never
    # used it — removed as dead code.
    from imblearn.over_sampling import SMOTE

    # SMOTE synthesizes new minority samples by interpolating between each
    # minority point and one of its k nearest minority neighbors (default k=5).
    smote = SMOTE(random_state=42)
    X_resampled, y_resampled = smote.fit_resample(X_train, y_train)
    print(f"Before SMOTE: {np.bincount(y_train)}")
    print(f"After SMOTE: {np.bincount(y_resampled)}")

    # Train a plain logistic regression on the balanced (resampled) training set;
    # evaluation stays on the untouched, imbalanced test set.
    lr_smote = LogisticRegression(random_state=42, max_iter=1000)
    lr_smote.fit(X_resampled, y_resampled)
    smote_preds = lr_smote.predict(X_test)
    smote_proba = lr_smote.predict_proba(X_test)[:, 1]
    print("\n--- Logistic Regression + SMOTE ---")
    print(classification_report(y_test, smote_preds, target_names=['Legit', 'Fraud']))
    smote_available = True
except ImportError:
    # Graceful fallback when imbalanced-learn is not installed: explain the
    # algorithm, then emulate plain random oversampling with sklearn's resample.
    print("imbalanced-learn not installed. Install with: pip install imbalanced-learn")
    print("\nHow SMOTE works conceptually:")
    print("1. For each minority sample, find k nearest neighbors (default k=5)")
    print("2. Randomly select one neighbor")
    print("3. Create a synthetic sample on the line segment between them")
    print("4. Repeat until minority class is balanced")

    # Manual oversampling as fallback: duplicate minority rows (with replacement)
    # until both classes have equal counts.
    from sklearn.utils import resample
    X_minority = X_train[y_train == 1]
    y_minority = y_train[y_train == 1]
    X_majority = X_train[y_train == 0]
    y_majority = y_train[y_train == 0]
    X_minority_up, y_minority_up = resample(
        X_minority, y_minority,
        n_samples=len(y_majority), random_state=42
    )
    X_resampled = np.vstack([X_majority, X_minority_up])
    y_resampled = np.concatenate([y_majority, y_minority_up])
    print(f"\nManual oversampling result: {np.bincount(y_resampled)}")

    lr_over = LogisticRegression(random_state=42, max_iter=1000)
    lr_over.fit(X_resampled, y_resampled)
    over_preds = lr_over.predict(X_test)
    print("\n--- Logistic Regression + Manual Oversampling ---")
    print(classification_report(y_test, over_preds, target_names=['Legit', 'Fraud']))
    smote_available = False
SMOTE gotcha: Always apply SMOTE only on training data, never on test data. If using pipelines with imblearn, use imblearn.pipeline.Pipeline, not sklearn's — it respects the fit/transform distinction for samplers.
5. Threshold Tuning: Find the Optimal Decision Boundary
from sklearn.metrics import f1_score, precision_score, recall_score

# Sweep decision thresholds over the balanced model's fraud probabilities.
# NOTE(review): tuning the threshold on the test set is for demonstration only —
# in practice, pick it on a separate validation set.
thresholds = np.linspace(0.01, 0.99, 200)

def _scores_at(cutoff):
    # F1 / precision / recall of the hard predictions at a given cutoff.
    labels = (balanced_proba >= cutoff).astype(int)
    return (
        f1_score(y_test, labels, zero_division=0),
        precision_score(y_test, labels, zero_division=0),
        recall_score(y_test, labels, zero_division=0),
    )

f1_scores, precision_scores, recall_scores = (
    list(col) for col in zip(*(_scores_at(t) for t in thresholds))
)

# Threshold that maximizes F1 on the sweep.
best_threshold = thresholds[int(np.argmax(f1_scores))]

fig, ax = plt.subplots(figsize=(9, 5))
for curve, curve_label, curve_color in (
    (f1_scores, 'F1', 'green'),
    (precision_scores, 'Precision', 'blue'),
    (recall_scores, 'Recall', 'orange'),
):
    ax.plot(thresholds, curve, label=curve_label, color=curve_color, lw=2)
ax.axvline(best_threshold, color='red', linestyle='--',
           label=f'Best F1 threshold = {best_threshold:.3f}')
ax.axvline(0.5, color='gray', linestyle=':', alpha=0.7, label='Default threshold = 0.5')
ax.set_xlabel('Decision Threshold')
ax.set_ylabel('Score')
ax.set_title('Threshold Tuning: Precision-Recall-F1 Tradeoff')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Re-threshold the probabilities at the optimum and report.
final_preds = (balanced_proba >= best_threshold).astype(int)
print(f"Best threshold: {best_threshold:.3f}")
print("At best threshold:")
print(classification_report(y_test, final_preds, target_names=['Legit', 'Fraud']))
6. Comparison Table
from sklearn.metrics import f1_score, recall_score

# Side-by-side evaluation of each approach on the held-out test set.
# Each entry maps a display name to (hard predictions, fraud probabilities).
approaches = {
    'Baseline (no fix)': (lr_preds, lr.predict_proba(X_test)[:, 1]),
    'class_weight=balanced': (balanced_preds, balanced_proba),
    'Threshold tuning (balanced)': (final_preds, balanced_proba),
}

results = [
    {
        'Approach': name,
        'Accuracy': f"{accuracy_score(y_test, preds):.3f}",
        'Precision (fraud)': f"{precision_score(y_test, preds, zero_division=0):.3f}",
        'Recall (fraud)': f"{recall_score(y_test, preds, zero_division=0):.3f}",
        'F1 (fraud)': f"{f1_score(y_test, preds, zero_division=0):.3f}",
        'ROC-AUC': f"{roc_auc_score(y_test, proba):.3f}",
        'PR-AUC': f"{average_precision_score(y_test, proba):.3f}",
    }
    for name, (preds, proba) in approaches.items()
]

df_results = pd.DataFrame(results)
print(df_results.to_string(index=False))

print("\nKey observations:")
print(" - Accuracy stays high across approaches (misleading metric)")
print(" - Recall jumps significantly with class_weight fix")
print(" - Threshold tuning finds the precision-recall sweet spot")
print(" - ROC-AUC and PR-AUC are same for same model (only threshold changes)")
Exercises
1. Extreme imbalance: Create a dataset with a 999:1 ratio. Does `class_weight='balanced'` still work? At what imbalance ratio does it break down?
2. Metric selection: For a medical diagnostic test, is it worse to have false positives or false negatives? Set the threshold accordingly and report the tradeoffs.
3. SMOTE variants: If you have `imbalanced-learn` installed, compare `SMOTE`, `ADASYN`, `BorderlineSMOTE`, and `RandomOverSampler`. Which works best on a dataset with 10:1 imbalance?
4. Undersampling: Instead of oversampling the minority, try undersampling the majority with `RandomUnderSampler`. Compare to SMOTE — what are the tradeoffs?
5. Business-driven threshold: Define a custom cost matrix: each false negative (missed fraud) costs $500, each false positive (blocked legit transaction) costs $10. Write a function to find the threshold that minimizes total expected cost.