Evaluation¶
ngboost-lightning provides diagnostic tools for assessing the quality of
probabilistic predictions. These are available in the ngboost_lightning.evaluation
module.
PIT (Probability Integral Transform)¶
For a well-calibrated model, the CDF values F(y) evaluated at the true
observations should follow a Uniform(0, 1) distribution. Deviations from
uniformity indicate miscalibration.
from ngboost_lightning.evaluation import pit_values, plot_pit_histogram
dist = reg.pred_dist(X_test)
pit = pit_values(dist, y_test)
# Visual check — histogram should be flat
plot_pit_histogram(pit)
Interpreting the PIT histogram:
- Flat (uniform) — well calibrated.
- U-shaped — underdispersed (prediction intervals too narrow).
- Hump-shaped — overdispersed (prediction intervals too wide).
- Skewed — systematic bias in location.
Calibration Curve¶
The calibration curve plots observed coverage against expected quantile levels. A perfectly calibrated model lies on the diagonal.
from ngboost_lightning.evaluation import (
calibration_regression,
calibration_error,
plot_calibration_curve,
)
obs, exp = calibration_regression(dist, y_test, bins=11)
plot_calibration_curve(obs, exp)
# Scalar summary: mean absolute deviation from diagonal
cal_err = calibration_error(dist, y_test)
Survival Evaluation¶
Concordance Index¶
The concordance index (C-index) measures discrimination — how well the model ranks patients by predicted risk. A value of 1.0 means perfect ranking; 0.5 means random.
from ngboost_lightning.evaluation import concordance_index
dist = surv.pred_dist(X_test)
c_index = concordance_index(dist, T_test, E_test)
Survival Calibration¶
Analogous to regression calibration, but uses the survival function:
from ngboost_lightning.evaluation import calibration_survival
obs, exp = calibration_survival(dist, T_test, E_test)
Plotting¶
Plot functions require matplotlib (optional dependency). They return
the matplotlib Axes object for further customization:
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
plot_pit_histogram(pit, ax=axes[0])
plot_calibration_curve(obs, exp, ax=axes[1])
plt.tight_layout()
plt.show()
If no ax is provided, the functions create a new figure automatically.
Function Reference¶
| Function | Description |
|---|---|
pit_values(dist, y) |
PIT values F(y), shape [n_samples] |
calibration_regression(dist, y) |
(expected, observed) quantile calibration |
calibration_error(dist, y) |
Scalar mean absolute calibration error |
calibration_survival(dist, T, E) |
Survival calibration curve |
concordance_index(dist, T, E) |
C-index for survival discrimination |
plot_pit_histogram(pit) |
Histogram of PIT values |
plot_calibration_curve(obs, exp) |
Calibration curve plot |
See the API Reference for full signatures and details.