The right way to do predictive checks with observation-level random effects

Categories: Prediction, Statistics, brms, PyMC

Author: Erik J. Ringen

Published: November 16, 2024

Observation-level random effects (OLRE) are an effective way to model overdispersed count data (Harrison 2014). For example, OLREs relax the assumption of Poisson regression that the variance is equal to the mean by giving each observation \(i\) a “random intercept” (\(\nu_i\)):

\[ y_i \sim \text{Poisson}(\lambda_i)\] \[ \text{log}(\lambda_i) = b_0 + \nu_i\] \[ \nu_i \sim \mathcal{N} (0, \sigma) \]

By putting \(\nu\) inside the linear model, we smuggle a variance component (\(\sigma\)) into a distribution that otherwise has only a single rate parameter (this trick also works for the Binomial distribution (Harrison 2015)). OLREs are also used to capture residual correlations in multi-response models (Hadfield 2010).
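To see what the extra variance component does to the counts, here is a minimal simulation sketch (plain NumPy; the intercept, \(\sigma\), and sample size are arbitrary illustrative values, not taken from the data used below):

Code
import numpy as np

rng = np.random.default_rng(1)
b0, sigma, n = 2.0, 0.5, 100_000  # illustrative values only

# Plain Poisson: variance equals the mean
y_pois = rng.poisson(np.exp(b0), size=n)

# OLRE Poisson: each observation gets its own nu_i ~ Normal(0, sigma)
nu = rng.normal(0.0, sigma, size=n)
y_olre = rng.poisson(np.exp(b0 + nu))

print(y_pois.mean(), y_pois.var())   # roughly equal
print(y_olre.mean(), y_olre.var())   # variance clearly exceeds the mean

Marginally, the OLRE counts follow a Poisson-lognormal mixture, which is why their variance exceeds their mean.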

Predictive checks are a routine form of model checking used to understand a model's ability to represent the data. Unfortunately, most software for performing predictive checks handles OLREs the wrong way by default, giving an inflated impression of goodness of fit. In this post I will show you how to do it the right way, in either R + brms or Python + PyMC. I focus on Bayesian models here, but the basic idea also holds when checking frequentist models fit with software such as lme4 and glmmTMB.

The wrong way

In Bayesian workflow, we often perform posterior predictive checks, where draws from the posterior distribution are used to generate many synthetic replications of our dataset, denoted \(y_{\text{rep}}\), which are then compared to the observed values. Systematic discrepancies between the distribution of \(y_\text{rep}\) and the actual data indicate misspecification, and can suggest ways to improve our models.
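Formally, each replicated dataset is a draw from the posterior predictive distribution,

\[ p(y_{\text{rep}} \mid y) = \int p(y_{\text{rep}} \mid \theta)\, p(\theta \mid y)\, d\theta \]

which in practice we approximate by taking each posterior draw \(\theta^{(s)}\) and simulating \(y_{\text{rep}}^{(s)} \sim p(y \mid \theta^{(s)})\).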

To illustrate this idea, we'll use the Oceanic toolkit complexity dataset from Michelle Kline and Robert Boyd (Kline and Boyd 2010). The response variable is the count of unique tools in a given Oceanic society (total_tools), which is predicted by the natural logarithm of population size (population). First we will fit this basic model, using priors from McElreath (2020). For each society \(i\):

\[ \text{total\_tools}_i \sim \text{Poisson}(\lambda_i)\] \[ \text{log}(\lambda_i) = b_0 + b_{\text{pop}}\text{log}(\text{population}_{i})_z\] \[ b_0 \sim \mathcal{N}(3, 0.5)\] \[ b_{\text{pop}} \sim \mathcal{N}(0, 0.2)\]

Where log(population) has been standardized to have zero mean and unit variance. Then, we will run posterior predictive checks using some off-the-shelf convenience functions.

Code
library(brms)
library(dplyr)
library(ggplot2)
library(bayesplot)
library(tidybayes)
library(patchwork)
set.seed(123)

Kline <- read.csv("https://raw.githubusercontent.com/rmcelreath/rethinking/refs/heads/master/data/Kline.csv", sep=";")

Kline$log_pop_z <- scale(log(Kline$population)) # standardize

m_poisson <- brm(
    total_tools ~ 1 + log_pop_z,
    family = poisson(link = "log"),
    prior = prior(normal(3, 0.5), class = "Intercept") + 
        prior(normal(0, 0.2), class = "b"),
    chains = 1,
    data = Kline,
    seed = 123)

color_scheme_set("teal")
theme_set(theme_classic(base_size = 13))

brms::pp_check(m_poisson, type = "dens_overlay", ndraws = 200) + 
    theme(legend.position = "none") + 
    brms::pp_check(m_poisson, type = "intervals") + 
    plot_layout(guides = 'collect') + 
    theme_classic(base_size = 13) +
    plot_annotation(subtitle = "Basic Poisson PPC") 

Posterior predictive checks for basic Poisson model. (left) replicated and observed densities, (right) observation-level reps, with bars representing 50% and 90% credible intervals.
Code
import matplotlib.colors as cols
import matplotlib.pyplot as plt
plt.rcParams.update({
    'font.size': 30,
    'axes.titlesize': 32,         
    'axes.labelsize': 30,        
    'xtick.labelsize': 28,      
    'ytick.labelsize': 28, 
    'legend.fontsize': 32,
    'lines.linewidth': 1.5,
})

import pandas as pd
import pymc as pm
import numpy as np
from scipy import stats
from scipy.stats import gaussian_kde
import arviz as az

Kline = pd.read_csv("https://raw.githubusercontent.com/rmcelreath/rethinking/refs/heads/master/data/Kline.csv", sep=";")

Kline['log_pop_z'] = stats.zscore(np.log(Kline['population']))

with pm.Model() as m_poisson:
    # priors
    b0 = pm.Normal("Intercept", mu=3, sigma=0.5)
    b_pop = pm.Normal("slope", mu=0, sigma=0.2)
    # linear model
    log_lam = b0 + b_pop * Kline['log_pop_z']
    ## Poisson likelihood
    y = pm.Poisson("y", mu=pm.math.exp(log_lam), observed=Kline['total_tools'])

    idata = pm.sample(4000, chains=1, random_seed=123)
    pm.sample_posterior_predictive(idata, extend_inferencedata=True, random_seed=1);
Code


idata.observed_data['y'] = idata.observed_data['y'].astype(np.float64) # Convert observed data to float for visualization
idata.posterior_predictive['y'] = idata.posterior_predictive['y'].astype(np.float64)

# Define plot functions
def plot_ppc_dens(y, yrep, ax, ax_index, num_samples=200):
    yrep = yrep.values
    y = y.values 

    for i in range(num_samples):
        sample = yrep[0, i, :]
        kde_sample = gaussian_kde(sample)
        x_values = np.linspace(yrep.min(), yrep.max(), 200)
        ax[ax_index].plot(x_values, kde_sample(x_values), color=(0.0, 0.486, 0.486, 0.05))  # Use low alpha for transparency

    kde = gaussian_kde(y)
    x_values = np.linspace(yrep.min(), yrep.max(), 200)
    ax[ax_index].plot(x_values, kde(x_values), color="#007c7c", linewidth=6)

    ax[ax_index].set_xlabel('')
    ax[ax_index].set_ylabel('')
    ax[ax_index].margins(y=0)
    ax[ax_index].margins(x=0)
    ax[ax_index].spines['top'].set_visible(False)
    ax[ax_index].spines['right'].set_visible(False)

def plot_ppc_intervals(y, yrep, ax, ax_index):
    y = y.values
    yrep = yrep.stack(sample=("chain", "draw")).values

    median_predictions = np.median(yrep, axis=1)

    # Define x-axis values
    x = np.arange(len(y))

    intervals = [(25, 75), (5, 95)]
    colors = ['#007C7C', '#007C7C']
    labels = ['50% Interval', '90% Interval']

    for (low, high), color, label in zip(intervals, colors, labels):
        lower_bounds = np.percentile(yrep, low, axis=1)
        upper_bounds = np.percentile(yrep, high, axis=1)
        error_lower = median_predictions - lower_bounds
        error_upper = upper_bounds - median_predictions
        error = [error_lower, error_upper]

        ax[ax_index].errorbar(
        x,
        median_predictions,
        yerr=error,
        fmt='o',
        color='#007C7C',
        ecolor=color,
        elinewidth=8,
        capsize=0,
        label=label,
        alpha=0.1,
        markersize=20
    )

    # Overlay observed data points
    ax[ax_index].scatter(
        x,
        y,
        color='#007C7C',
        label='Observed Data',
        zorder=5,
        s = 150
    )

    # Customize the plot
    ax[ax_index].set_xlabel('Data point (index)')
    ax[ax_index].set_ylabel('')
    ax[ax_index].set_title('')
    ax[ax_index].legend(['y', 'yrep'], loc="upper left")
    ax[ax_index].spines['top'].set_visible(False)
    ax[ax_index].spines['right'].set_visible(False)


fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 13))

plot_ppc_dens(idata.observed_data['y'], idata.posterior_predictive['y'], axes, 0)
plot_ppc_intervals(idata.observed_data['y'], idata.posterior_predictive['y'], axes, 1)
axes[0].set_title('Basic Poisson PPC', loc = "left")

Posterior predictive checks for basic Poisson model. (left) replicated and observed densities, (right) observation-level reps, with bars representing 50% and 90% credible intervals.

We see some indications that the data are overdispersed relative to the model. Namely, the distribution of the observed \(y\) (total_tools) appears “flat” compared to \(y_{\text{rep}}\) in the left-side plot. We can also see in the right-side plot that several observed values fall outside of the 90% credible intervals, suggesting that the model’s predictions are too precise. So, let’s try adding an OLRE to capture this overdispersion.
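Before updating the model, a quick numerical complement to the plots: compare the variance-to-mean ratio of the observed counts with the same statistic computed on each replicated dataset. The sketch below reuses the idata object from the PyMC fit above (bayesplot's ppc_stat offers an analogous check on the brms side).

Code
import numpy as np

y_obs = Kline['total_tools'].to_numpy()
# posterior predictive draws, shape (chain, draw, obs) -> (n_draws, n_obs)
yrep = idata.posterior_predictive['y'].values.reshape(-1, len(y_obs))

def var_mean_ratio(x, axis=-1):
    # index of dispersion (variance / mean)
    return x.var(axis=axis) / x.mean(axis=axis)

t_obs = var_mean_ratio(y_obs)
t_rep = var_mean_ratio(yrep, axis=1)

# proportion of replicated datasets at least as dispersed as the observed data
print(f"T(y) = {t_obs:.1f}, median T(yrep) = {np.median(t_rep):.1f}")
print(f"Pr(T(yrep) >= T(y)) = {np.mean(t_rep >= t_obs):.3f}")

Here’s our updated model definition: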

\[ \text{total\_tools}_i \sim \text{Poisson}(\lambda_i)\] \[ \text{log}(\lambda_i) = b_0 + \nu_i + b_{\text{pop}}\text{log}(\text{population}_{i})_z\] \[ b_0 \sim \mathcal{N}(3, 0.5)\] \[ b_{\text{pop}} \sim \mathcal{N}(0, 0.2)\] \[ \nu_i \sim \mathcal{N}(0, \sigma)\] \[ \sigma \sim \text{Exponential}(2) \]

Code
Kline$obs <- 1:nrow(Kline)

m_poisson_OLRE <- brm(
    total_tools ~ 1 + log_pop_z + (1|obs),
    family = poisson(link = "log"),
    prior = prior(normal(3, 0.5), class = "Intercept") + 
        prior(normal(0, 0.2), class = "b") + 
        prior(exponential(2), class = "sd"),
    chains = 1,
    control = list(adapt_delta = 0.95),
    data = Kline,
    seed=123,
    save_pars = save_pars(all = TRUE))

brms::pp_check(m_poisson_OLRE, type = "dens_overlay", ndraws = 200) + theme(legend.position = "none") + brms::pp_check(m_poisson_OLRE, type = "intervals") +
    plot_annotation(subtitle = "OLRE PPC: The wrong way") 

Code
obs_idx = np.arange(len(Kline))
coords = {"obs": obs_idx}

with pm.Model(coords = coords) as m_poisson_OLRE:
    # priors
    b0 = pm.Normal("Intercept", mu=3, sigma=0.5)
    b_pop = pm.Normal("slope", mu=0, sigma=0.2)
    nu_z = pm.Normal("nu_z", mu = 0, sigma = 1, dims = "obs")
    sigma = pm.Exponential("sigma", lam = 2)
    nu = pm.Deterministic("nu", nu_z*sigma)
    # linear model
    log_lam = b0 + nu[obs_idx] + b_pop * Kline['log_pop_z'] 
    ## Poisson likelihood
    y = pm.Poisson("y", mu=pm.math.exp(log_lam), observed=Kline['total_tools'])

    idata_OLRE = pm.sample(4000, chains = 1, target_accept = 0.95, random_seed=123)
    pm.sample_posterior_predictive(idata_OLRE, extend_inferencedata=True, random_seed=1);
Code

idata_OLRE.observed_data['y'] = idata_OLRE.observed_data['y'].astype(np.float64) # Convert observed data to float for visualization
idata_OLRE.posterior_predictive['y'] = idata_OLRE.posterior_predictive['y'].astype(np.float64)

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 13))

plot_ppc_dens(idata_OLRE.observed_data['y'], idata_OLRE.posterior_predictive['y'], axes, 0)
plot_ppc_intervals(idata_OLRE.observed_data['y'], idata_OLRE.posterior_predictive['y'], axes, 1)
axes[0].set_title('OLRE PPC: The wrong way', loc = "left")

Incorrect posterior predictive checks for OLRE Poisson model, using off-the-shelf convenience functions. (left) replicated and observed densities, (right) observation-level reps, with bars representing 50% and 90% credible intervals.

Looks good, right? Sadly, this is a little too good to be true. We have misled ourselves. To see why, let’s examine the posterior distribution of the OLREs in relation to the observed data:

Code
m_poisson_OLRE |> 
    spread_draws(r_obs[obs,]) |> 
    median_qi(estimate = r_obs, .width = 0.9) |> 
    left_join(Kline, by = "obs") |> 
    ggplot(aes(x = estimate, y = log(total_tools), xmin = .lower, xmax = .upper)) +
    geom_pointinterval() + 
    theme_classic(base_size = 15) + 
    labs(x = expression(nu), y = "log(total tools)", title = "y ~ Fitted OLRE")

Fitted nu parameters versus the natural log of the response, total tools. Bars represent 90% credible intervals.
Code
nu_fitted_summary = az.summary(idata_OLRE, var_names=["^nu(?!_z).*"], filter_vars="regex", hdi_prob=0.9)

fig = plt.figure(figsize=(20, 13))
plt.hlines(y=np.log(Kline['total_tools']), xmin=nu_fitted_summary['hdi_5%'], xmax = nu_fitted_summary['hdi_95%'], color="black", linewidth=6)
plt.scatter(y=np.log(Kline['total_tools']), x=nu_fitted_summary['mean'], color="black", s = 150)

plt.title('y ~ Fitted OLRE', loc = "left")
plt.xlabel(r'$\nu$')
plt.ylabel('log(total tools)')

Fitted nu parameters versus the natural log of the response, total tools. Bars represent 90% credible intervals.

This plot shows us that the OLREs are positively correlated with the values of the observed data. Why? These parameters are doing exactly what they are supposed to do: capture excess dispersion in the data by learning which points are higher or lower than we would expect, given their population size. To understand why this is a problem, consider out-of-sample prediction: when generating \(y_\text{rep}\) for a new observation, we don’t know \(y_{\text{test}}\) in advance, so the OLRE should convey no information about it. In our naive predictive check, we have mistakenly treated \(\nu\) as fixed, when really it should be replicated along with \(y_\text{rep}\), akin to \(\epsilon\) in a linear regression. Generating \(y_\text{rep}\) this way is referred to as “mixed replication”, because we leave the hyperparameter \(\sigma\) fixed but replicate each random effect parameter (Gelman, Meng, and Stern 1996).
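In the notation of our model, mixed replication means that for each posterior draw of \((b_0, b_{\text{pop}}, \sigma)\) we generate

\[ \nu_{\text{rep},i} \sim \mathcal{N}(0, \sigma) \] \[ y_{\text{rep},i} \sim \text{Poisson}\left(\exp(b_0 + \nu_{\text{rep},i} + b_{\text{pop}}\text{log}(\text{population}_{i})_z)\right) \]

rather than reusing the fitted \(\nu_i\).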

The right way

The way out of this is straightforward. All we have to do is replace the fitted OLREs with new levels, denoted \(\nu_{\text{rep}}\), which are generated using posterior draws of the observation-level standard deviation \(\sigma\).
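To make the mechanics explicit before reaching for convenience functions, here is a minimal hand-rolled sketch in NumPy, using the posterior draws in idata_OLRE from the PyMC fit above:

Code
import numpy as np

post = idata_OLRE.posterior
b0 = post['Intercept'].values.reshape(-1)   # (n_draws,)
b_pop = post['slope'].values.reshape(-1)
sigma = post['sigma'].values.reshape(-1)
x = Kline['log_pop_z'].to_numpy()           # (n_obs,)

rng = np.random.default_rng(3)

# For every posterior draw, replicate nu from N(0, sigma) instead of reusing
# the fitted observation-level effects, then draw y_rep from the Poisson
nu_rep = rng.normal(0.0, sigma[:, None], size=(len(sigma), len(x)))
lam_rep = np.exp(b0[:, None] + nu_rep + b_pop[:, None] * x[None, :])
yrep_mixed = rng.poisson(lam_rep)           # (n_draws, n_obs)

The same thing with the built-in machinery in brms and PyMC: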

Code
yrep_OLRE <- m_poisson_OLRE |> 
    posterior_predict(newdata = Kline |> 
    mutate(obs = paste("OLRE_rep", 1:n())),
     allow_new_levels = TRUE,
     sample_new_levels = "gaussian")

bayesplot::ppc_dens_overlay(Kline$total_tools, yrep_OLRE[1:100,]) + theme(legend.position = "none") + bayesplot::ppc_intervals(Kline$total_tools, yrep_OLRE) +
    plot_annotation(subtitle = "OLRE PPC: The right way") 

Correct posterior predictive checks for OLRE Poisson model, sampling new levels of nu (nu_rep). (left) replicated and observed densities, (right) observation-level reps, with bars representing 50% and 90% credible intervals.
Code
# resample both y and nu_z
pred_yrep = pm.sample_posterior_predictive(idata_OLRE, m_poisson_OLRE, predictions=True, extend_inferencedata=False, var_names = ['y', 'nu_z'], random_seed=2)
Code

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 13))

plot_ppc_dens(idata.observed_data['y'], pred_yrep.predictions['y'], axes, 0)
plot_ppc_intervals(idata.observed_data['y'], pred_yrep.predictions['y'], axes, 1)
axes[0].set_title('OLRE PPC: The right way', loc = "left")

Correct posterior predictive checks for OLRE Poisson model, sampling new levels of nu (nu_rep). (left) replicated and observed densities, (right) observation-level reps, with bars representing 50% and 90% credible intervals.

Notice that, unlike our first predictive check with no OLRE, the credible intervals of \(y_{\text{rep}}\) all contain the observed values of \(y\). But unlike our (wrong) second predictive check, the predictions do not conform so closely to the observed values, because each \(\nu_{\text{rep}}\) is independent of \(y\). This gives a more realistic picture of the model’s fit. So why do most posterior predictive functions treat OLREs the wrong way by default? Because the software cannot know whether \(\nu\) is an OLRE or instead a parameter that should be held fixed across replications, like a random effect for group differences. The latter is more common, so the default is sensible, but not necessarily safe.

There is a caveat: even if we do it the “right way”, all posterior predictive checks are overly optimistic for out-of-sample data because they use the same data for fitting and evaluation. A model that performs well in these checks might actually have poor generalization to new data due to overfitting. This issue is not specific to OLREs, but in the final section I’ll show you how to address overfitting in predictive checks.

An even better way?

Leave-one-out cross validation (LOOCV) provides a more honest assessment of predictive accuracy by holding out one observation at a time as a test point. Since the model never sees the held-out observation during fitting, these model checks will reflect true predictive performance rather than retrodiction of the sample. In the code below we will do exact LOOCV, refitting the model and making predictions for the left-out point \(N = 10\) times.

Code
yrep_loo <- matrix(NA, nrow = ndraws(m_poisson_OLRE), ncol = nrow(Kline))

for (i in 1:nrow(Kline)) {
    Kline_train <- Kline[-i,]
    Kline_test <- Kline[i,]
    # recompute z-scores to avoid leakage
    mean_log_pop <- mean(log(Kline_train$population))
    sd_log_pop <- sd(log(Kline_train$population))

    model_loo <- update(m_poisson_OLRE, newdata = Kline_train |> mutate(log_pop_z = (log(population) - mean_log_pop) / sd_log_pop), seed = 123)
    yrep_loo[,i] <- posterior_predict(model_loo, newdata = Kline_test |> mutate(log_pop_z = (log(population) - mean_log_pop) / sd_log_pop), allow_new_levels = TRUE, sample_new_levels = "gaussian")
}

bayesplot::ppc_dens_overlay(Kline$total_tools, yrep_loo[1:100,]) + theme(legend.position = "none") + bayesplot::ppc_intervals(Kline$total_tools, yrep_loo) +
    plot_annotation(subtitle = "OLRE PPC-LOO") 

Exact LOOCV posterior predictive checks for OLRE Poisson model. (left) replicated and observed densities, (right) observation-level reps, with bars representing 50% and 90% credible intervals.
Code
def model_factory(train, test):
    obs_idx = np.arange(len(train))
    coords = {"obs_id": obs_idx}

    # recompute z-scores to avoid leakage
    mean_log_pop = np.mean(np.log(train['population']))
    sd_log_pop = np.std(np.log(train['population']))

    train['log_pop_z'] = stats.zscore(np.log(train['population']))

    with pm.Model(coords = coords) as model:
        log_pop_z = pm.Data("log_pop_z", train['log_pop_z'], dims = "obs_id")
        y = pm.Data("y", train['total_tools'], dims = "obs_id")

        b0 = pm.Normal("Intercept", mu=3, sigma=0.5)
        b_pop = pm.Normal("slope", mu=0, sigma=0.2)
        nu_z = pm.Normal("nu_z", mu = 0, sigma = 1, dims = "obs_id")
        sigma = pm.Exponential("sigma", lam = 2)
        nu = pm.Deterministic("nu", nu_z*sigma, dims = "obs_id")
        # linear model
        log_lam = pm.Deterministic("log_lam", b0 + nu + b_pop * log_pop_z, dims = "obs_id")
        ## Poisson likelihood
        y = pm.Poisson("y_obs", mu=pm.math.exp(log_lam), observed=y, dims = "obs_id")
    
        idata_loo = pm.sample(1000, chains = 1, target_accept = 0.99, random_seed=99)
    
    with model:
        # Swapping in the single test point resizes the "obs_id" dimension, so the
        # fitted nu_z draws in the trace no longer match the model's shape;
        # sample_posterior_predictive therefore cannot reuse them and draws fresh
        # nu_z values, i.e. new levels, as in the "right way" check above
        pm.set_data({"log_pop_z": (np.log(test['population']) - mean_log_pop)/sd_log_pop, "y": test['total_tools']}, coords={"obs_id": np.array([len(train)+1])})
        
        idata_loo.extend(pm.sample_posterior_predictive(idata_loo, random_seed=2))

    return idata_loo.posterior_predictive['y_obs']

test_preds = []

for i in range(len(Kline)):
    train = Kline.drop(index = Kline.index[i])
    test = Kline.iloc[[i]]
    yrep_loo = model_factory(train, test)
    test_preds.append(yrep_loo)
Code

import xarray as xr

loo_preds_combined = xr.concat(test_preds, dim='y_obs').transpose('chain', 'draw', 'y_obs', 'obs_id').squeeze('obs_id')
        
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 13))
plot_ppc_dens(idata.observed_data['y'], loo_preds_combined, axes, 0)
plot_ppc_intervals(idata.observed_data['y'], loo_preds_combined, axes, 1)
axes[0].set_title('OLRE PPC-LOO', loc = "left")

Exact LOOCV posterior predictive checks for OLRE Poisson model. (left) replicated and observed densities, (right) observation-level reps, with bars representing 50% and 90% credible intervals.

This predictive check is a bit less optimistic, but it still looks better than the first Poisson model we fit. This suggests that the OLRE is helpful and we have not overfit too badly. For large datasets, it becomes infeasible to refit the model for each observation, so one might turn to K-fold cross-validation or Pareto-smoothed importance sampling as an approximation (PSIS-LOO) (Vehtari, Gelman, and Gabry 2017). Note, however, that PSIS may not be reliable for these types of models: because each observation has its own parameter, dropping observation \(i\) changes the posterior of \(\nu_i\) substantially and the importance weights become heavy-tailed. In that case one should instead integrate the OLREs out of the likelihood, for example with adaptive quadrature.
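The Pareto \(\hat{k}\) diagnostics make this failure mode visible. A sketch, assuming the m_poisson_OLRE model and idata_OLRE trace from above (the pointwise log-likelihood has to be computed first because it was not stored during sampling):

Code
import arviz as az
import pymc as pm

# add pointwise log-likelihood values to the InferenceData
with m_poisson_OLRE:
    pm.compute_log_likelihood(idata_OLRE)

loo_result = az.loo(idata_OLRE, pointwise=True)
print(loo_result)           # ArviZ warns when many k values are large
print(loo_result.pareto_k)  # for OLRE models, many k values typically exceed 0.7

The brms equivalent is loo(m_poisson_OLRE), which reports the same diagnostics.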


References

Gelman, Andrew, Xiao-Li Meng, and Hal Stern. 1996. “Posterior Predictive Assessment of Model Fitness via Realized Discrepancies.” Statistica Sinica, 733–60.
Hadfield, Jarrod D. 2010. “MCMC Methods for Multi-Response Generalized Linear Mixed Models: The MCMCglmm R Package.” Journal of Statistical Software 33: 1–22.
Harrison, Xavier A. 2014. “Using Observation-Level Random Effects to Model Overdispersion in Count Data in Ecology and Evolution.” PeerJ 2: e616.
———. 2015. “A Comparison of Observation-Level Random Effect and Beta-Binomial Models for Modelling Overdispersion in Binomial Data in Ecology & Evolution.” PeerJ 3: e1114.
Kline, Michelle A, and Robert Boyd. 2010. “Population Size Predicts Technological Complexity in Oceania.” Proceedings of the Royal Society B: Biological Sciences 277 (1693): 2559–64.
McElreath, Richard. 2020. “Statistical Rethinking: A Bayesian Course with Examples in R and Stan.” 2nd ed. Chapman & Hall/CRC.
Vehtari, Aki, Andrew Gelman, and Jonah Gabry. 2017. “Practical Bayesian Model Evaluation Using Leave-One-Out Cross-Validation and WAIC.” Statistics and Computing 27: 1413–32.