Bayesian Modeling and Computation in Pyro - Chapter 5


Gabriel Stechschulte


July 23, 2022

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import arviz as az
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints, transforms
from pyro.infer import Predictive, TracePredictive, NUTS, MCMC
from pyro.infer.autoguide import AutoLaplaceApproximation
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from patsy import dmatrix

# Global plotting defaults for every figure in this notebook.
# NOTE(review): the `plt.style.use(` prefix was lost in extraction and has
# been restored from the surviving `'ggplot')` fragment.
plt.style.use('ggplot')
plt.rcParams["figure.figsize"] = (9, 4)

5.5 - Fitting splines in Pyro

day = pd.read_csv('./data/Bike-Sharing-Dataset/day.csv')
hour = pd.read_csv('./data/Bike-Sharing-Dataset/hour.csv')
# Scale hourly rental counts into [0, 1] by dividing by the largest count.
hour['cnt_std'] = hour['cnt'] / hour['cnt'].max()
sns.scatterplot(x=hour['hr'], y=hour['cnt_std'], alpha=0.1, color='grey')
plt.xlabel('Hour of Day (0-23)')
plt.title('Actual Bike Demand');

num_knots = 6
# `num_knots` evenly spaced interior knots over hours 0-23; the [1:-1] slice
# drops the two endpoints, so `knot_list` holds exactly `num_knots` knots.
knot_list = torch.linspace(0, 23, num_knots + 2)[1:-1]

# Cubic B-spline basis over hour of day. `knot_list` is passed as-is: slicing
# it with [1:-1] a second time would silently drop two of the six knots.
# NOTE(review): the data value was lost in extraction; hour of day is used
# because the knots span 0-23 and every plot puts hours on the x axis.
B = dmatrix(
    "bs(hr, knots=knots, degree=3, include_intercept=True) - 1",
    {'hr': hour['hr'].values, 'knots': knot_list},
)

B = torch.tensor(np.asarray(B)).float()
cnt_bikes = torch.tensor(hour['cnt_std'].values).float()
hour_bikes = torch.tensor(hour['hr'].values).reshape(-1, 1).float()

Splines Model - MCMC

def splines(design_matrix, count_bikes=None):
    """Bayesian regression spline for the scaled bike counts.

    Args:
        design_matrix: 2-D basis tensor of shape (N, P), one row per
            observation and one column per basis function.
        count_bikes: optional length-N tensor of observed values; pass None
            to draw ``y`` from the model instead of conditioning on data.
    """
    N, P = design_matrix.shape

    # Shared prior scale for the coefficients, and the observation noise.
    tau = pyro.sample('tau', dist.HalfCauchy(1.))
    sigma = pyro.sample('sigma', dist.HalfNormal(1.))

    # One coefficient per basis column, all with a Normal(0, tau) prior.
    with pyro.plate('knot_list', P):
        beta = pyro.sample('beta', dist.Normal(0., tau))

    # Spline mean: (P,) @ (P, N) -> (N,), recorded as a deterministic site.
    mu = pyro.deterministic('mu', torch.matmul(beta, design_matrix.T))

    with pyro.plate('output', N):
        pyro.sample('y', dist.Normal(mu, sigma), obs=count_bikes)


# NOTE(review): restored from a truncated fragment that had landed inside the
# function body; it draws the model graph for the bike-sharing basis.
pyro.render_model(splines, (B, cnt_bikes), render_distributions=True)

kernel = NUTS(splines)
# 500 posterior draws after 300 warm-up steps.
mcmc_splines = MCMC(kernel, 500, 300)
mcmc_splines.run(B, cnt_bikes)

prior_predictive = Predictive(splines, num_samples=500)(B, None)
spline_samples = mcmc_splines.get_samples(500)
splines_predictive = Predictive(splines, spline_samples)(B, None)

# NOTE(review): the arguments of this call and the heads of the two plotting
# calls below were lost in extraction; they are reconstructed from the names
# defined above.
az_splines_pred = az.from_pyro(
    posterior=mcmc_splines,
    prior=prior_predictive,
    posterior_predictive=splines_predictive,
)

# Posterior-mean spline function over hour of day.
sns.lineplot(
    x=hour_bikes.flatten(), y=splines_predictive['mu'].mean(axis=0).T.flatten(),
)
# Basis column 5 weighted by its posterior-mean coefficient.
sns.lineplot(
    x=hour_bikes.flatten(), y=(B * spline_samples['beta'].mean(axis=0))[:, 5],
)

# Posterior-predictive mean and standard deviation for every observation.
cnt_mu = splines_predictive['y'].mean(axis=0).T.flatten()
cnt_std = splines_predictive['y'].std(axis=0).T.flatten()

df = pd.DataFrame({
    'hr': hour['hr'].values,
    'cnt_scaled': hour['cnt_std'].values,
    'cnt_mu': cnt_mu,
    'cnt_std': cnt_std,
    'cnt_high': cnt_mu + cnt_std,
    'cnt_low': cnt_mu - cnt_std,
})

df = df.sort_values(by=['hr'])

Figure 5.9

    x=df['hr'], y=df['cnt_mu'], color='blue')
    x=hour['hr'], y=hour['cnt_std'], color='grey', alpha=0.3
    x=df['hr'], y1=df['cnt_high'], y2=df['cnt_low'], color='grey',
plt.scatter(knot_list, np.zeros_like(knot_list), color='black')
plt.title('Cubic Splines using 6 Knots');

5.6 - Choosing knots and priors for splines

# One cubic B-spline basis per candidate number of interior knots.
Bs = []
num_knots = [3, 6, 9, 12, 18]
for nk in num_knots:
    # nk interior knots over hours 0-23 (the [1:-1] slice drops the endpoints),
    # using the same hour range as the single-model fit.
    knot_list = torch.linspace(0, 23, nk + 2)[1:-1]
    basis = dmatrix(
        'bs(hr, knots=knots, degree=3, include_intercept=True) - 1',
        {'hr': hour['hr'].values, 'knots': knot_list},
    )
    Bs.append(torch.tensor(np.asarray(basis)).float())

# Fit the spline model once per basis. `inf_data` ends up in `num_knots` order.
# A distinct loop name keeps the global `B` from being overwritten.
inf_data = []
for basis in Bs:
    mcmc_obj = MCMC(NUTS(splines), 500, 300)
    mcmc_obj.run(basis, cnt_bikes)

    post_samples = mcmc_obj.get_samples(500)
    post_pred = Predictive(splines, post_samples)(basis, None)

    inf_data.append(
        az.from_pyro(posterior=mcmc_obj, posterior_predictive=post_pred)
    )

# Because `inf_data` follows `num_knots` order, zipping the two pairs each
# label with the model it names.
# NOTE(review): in the comparison table recorded below, p_loo falls as the
# labelled knot count rises, which suggests that run attached the labels in
# reverse. The table also predates passing every interior knot to dmatrix, so
# it is stale.
dict_cmp = {f"m_{k}k": v for k, v in zip(num_knots, inf_data)}
cmp = az.compare(dict_cmp, ic='loo', var_name='y')
['m_18k', 'm_12k', 'm_9k', 'm_6k', 'm_3k']
rank elpd_loo p_loo elpd_diff weight se dse warning scale
m_3k 0 10575.715589 20.874493 0.000000 0.850828 129.781140 0.000000 False log
m_6k 1 10423.693963 14.518533 152.021626 0.000000 131.808422 19.387344 False log
m_9k 2 10094.535427 12.458113 481.180162 0.000000 133.635244 36.071652 False log
m_12k 3 9580.695289 8.562018 995.020300 0.000000 136.447458 53.955052 False log
m_18k 4 8600.222467 6.347949 1975.493122 0.149172 142.746908 81.838699 False log
colors = ['black', 'blue', 'grey', 'grey', 'black']
linestyle = ["-", "-", "--", "--", "-"]
linewidth = [1.5, 3, 1.5, 1.5, 3]

# Overlay the posterior-predictive mean of each fitted model. `inf_data` was
# built in `num_knots` order, so the labels come straight from `num_knots`;
# reverse-sorting them would mislabel every curve.
# NOTE(review): the plotting-call heads were lost in extraction and are
# reconstructed from their keyword arguments.
for ob, col, knots, ls, lw in zip(inf_data, colors, num_knots, linestyle, linewidth):
    sns.lineplot(
        x=hour['hr'], y=ob['posterior_predictive']['y'][0].mean(axis=0), color=col,
        label=f'knots={knots}', linestyle=ls, linewidth=lw,
    )

# The data points and title are the same for every model, so draw them once.
sns.scatterplot(
    x=hour['hr'], y=hour['cnt_std'].values, color='lightgrey', alpha=0.5, edgecolor='grey',
)
plt.title('Model fit with different number of knots');

5.6.1 Regularizing priors for splines

class GaussianRandomWalk(dist.TorchDistribution):
    """Gaussian random walk that starts at zero with a constant step scale.

    A draw is a length-``num_steps`` vector ``x`` with
    ``x[0] ~ Normal(0, scale)`` and ``x[t] ~ Normal(x[t-1], scale)``.

    Args:
        scale: positive tensor of step scales; its shape is the batch shape.
        num_steps: length of each walk (the event shape).
    """
    has_rsample = True
    arg_constraints = {'scale': constraints.positive}
    # Each draw is a whole vector (event_dim=1), so the support is declared
    # per vector rather than per scalar.
    support = constraints.real_vector

    def __init__(self, scale, num_steps=1):
        self.scale = scale
        batch_shape, event_shape = scale.shape, torch.Size([num_steps])
        super().__init__(batch_shape, event_shape)

    def rsample(self, sample_shape=torch.Size()):
        shape = sample_shape + self.batch_shape + self.event_shape
        # Standard-normal increments, summed along the walk and scaled per
        # batch element (trailing singleton dim broadcasts over the steps).
        walks = self.scale.new_empty(shape).normal_()
        return walks.cumsum(-1) * self.scale.unsqueeze(-1)

    def log_prob(self, x):
        # The first position is Normal(0, scale); each later position is Normal
        # around its predecessor. As in `rsample`, `scale` gets a trailing
        # singleton dim so that a batched scale lines up with its own batch
        # element instead of broadcasting against the step axis.
        step_scale = self.scale.unsqueeze(-1)
        init_prob = dist.Normal(self.scale.new_tensor(0.), self.scale).log_prob(x[..., 0])
        step_probs = dist.Normal(x[..., :-1], step_scale).log_prob(x[..., 1:])
        return init_prob + step_probs.sum(-1)
def splines_grw(design_matrix, count_bikes=None):
    """Spline regression whose coefficients follow a Gaussian random walk.

    Each coefficient is centred on its predecessor, which ties neighbouring
    basis functions together and regularizes the fitted curve.

    Args:
        design_matrix: 2-D basis tensor of shape (N, P).
        count_bikes: optional length-N tensor of observed values; pass None
            to draw ``y`` from the model instead.
    """
    N, P = design_matrix.shape

    tau = pyro.sample('tau', dist.HalfCauchy(1.))
    sigma = pyro.sample('sigma', dist.HalfNormal(1.))

    # A single random walk over all P coefficients, sized from the design
    # matrix so any number of knots works. Wrapping it in a size-P plate would
    # instead draw P independent walks and give `mu` an extra leading dim.
    beta = pyro.sample('beta', GaussianRandomWalk(scale=tau, num_steps=P))

    # (P,) @ (P, N) -> (N,)
    mu = pyro.deterministic('mu', torch.matmul(beta, design_matrix.T))

    with pyro.plate('output', N):
        pyro.sample('y', dist.Normal(mu, sigma), obs=count_bikes)
num_knots = 12
knot_list = torch.linspace(0, 23, num_knots + 2)[1:-1]

# NOTE(review): `knot_list[1:-1]` slices the already-interior knots a second
# time, so only 10 of the 12 knots reach dmatrix. Check how `splines_grw`
# sizes its random walk before passing `knot_list` unsliced.
# NOTE(review): the data value was lost in extraction; hour of day is used, as
# in the single-spline fit.
B = dmatrix(
    "bs(hr, knots=knots, degree=3, include_intercept=True) - 1",
    {'hr': hour['hr'].values, 'knots': knot_list[1:-1]},
)

B = torch.tensor(np.asarray(B)).float()
cnt_bikes = torch.tensor(hour['cnt_std'].values).float()
hour_bikes = torch.tensor(hour['hr'].values).reshape(-1, 1).float()

splines_grw_mcmc = MCMC(NUTS(splines_grw), 500, 300)
splines_grw_mcmc.run(B, cnt_bikes)
# NOTE(review): the chain holds 500 draws, so requesting 1000 must reuse
# draws; confirm this is intended.
splines_grw_samples = splines_grw_mcmc.get_samples(1000)
splines_grw_post_pred = Predictive(splines_grw, splines_grw_samples)(B, None)

# Average every leading dimension of the predictive draws (posterior samples,
# plus any extra per-walk dimension the model may produce) down to one value
# per observation.
grw_y = splines_grw_post_pred['y']
grw_mean = grw_y.reshape(-1, grw_y.shape[-1]).mean(0)

# NOTE(review): the plotting-call heads were lost in extraction and are
# reconstructed from their keyword arguments.
sns.lineplot(
    x=hour['hr'], y=grw_mean,
    color='blue', lw=2, label='splines_grw mean function',
)
sns.scatterplot(
    x=hour['hr'], y=hour['cnt_std'].values,
    color='lightgrey', alpha=0.5, edgecolor='grey',
)
plt.title('Gaussian Random Walk Prior');

Modeling CO₂ Uptake with Splines

plants_CO2 = pd.read_csv("./data/CO2_uptake.csv")
plant_names = plants_CO2.Plant.unique()
# Concentration grid taken from the first 7 rows.
# NOTE(review): assumes every plant was measured at these same 7
# concentrations, in the same order; confirm against the data.
CO2_conc = plants_CO2.conc.values[:7]
CO2_concs = plants_CO2.conc.values
uptake = plants_CO2.uptake.values
# Hard-coded 12 groups (one per plant).
index = range(12)
groups = len(index)
num_knots = 2
# `num_knots` evenly spaced interior knots over the concentration range.
knot_list = np.linspace(CO2_conc[0], CO2_conc[-1], num_knots+2)[1:-1]

# Spline basis evaluated at every row's concentration, for the pooled model.
Bg = dmatrix(
    "bs(conc, knots=knots, degree=3, include_intercept=True) - 1",
    {"conc": CO2_concs, "knots": knot_list},
)

Bg = torch.tensor(np.asarray(Bg)).float()
uptake = torch.tensor(uptake).float()

Pooled Model - MCMC

def single_response(design_matrix, obs=None):
    """Pooled spline model: one set of coefficients for every row.

    Args:
        design_matrix: 2-D basis tensor of shape (N, P).
        obs: optional length-N tensor of observed uptake; pass None to draw
            ``uptake`` from the model instead.
    """
    N, P = design_matrix.shape
    tau = pyro.sample('tau', dist.HalfCauchy(1.))
    sigma = pyro.sample('sigma', dist.HalfNormal(1.))

    # One coefficient per basis column with a shared Normal(0, tau) prior.
    with pyro.plate('coef', P):
        beta = pyro.sample('beta', dist.Normal(0., tau))

    # (P,) @ (P, N) -> (N,) fitted uptake, recorded as a deterministic site.
    ug = pyro.deterministic('ug', torch.matmul(beta, design_matrix.T))

    with pyro.plate('obs', N):
        pyro.sample('uptake', dist.Normal(ug, sigma), obs=obs)


# NOTE(review): restored from a truncated `(Bg, uptake),` fragment that had
# landed inside the function body; it draws the pooled model's graph.
pyro.render_model(single_response, (Bg, uptake), render_distributions=True)

# 500 posterior draws after 300 warm-up steps. The table that follows is
# presumably the printed output of `summary()`.
sr_mcmc = MCMC(NUTS(single_response), 500, 300)
sr_mcmc.run(Bg, uptake)
sr_mcmc.summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
   beta[0]     12.21      1.82     12.27      9.20     15.06    261.26      1.00
   beta[1]     30.06      3.52     30.11     24.50     35.95    251.13      1.00
   beta[2]     30.06      6.05     29.81     20.25     39.69    193.86      1.00
   beta[3]     33.86      9.16     33.72     20.00     49.44    171.13      1.00
   beta[4]     25.27     22.92     24.96    -14.30     60.07    186.48      1.01
   beta[5]     33.41      2.03     33.29     30.06     36.65    395.37      1.00
     sigma      6.77      0.38      6.75      6.24      7.48    596.67      1.00
       tau     31.24     10.56     28.94     17.24     47.38    226.15      1.00

# NOTE(review): the chain holds 500 draws, so requesting 1000 must reuse
# draws; confirm this is intended.
sr_samples = sr_mcmc.get_samples(1000)
sr_post_pred = Predictive(single_response, sr_samples)(Bg, None)
fig, axes = plt.subplots(4, 3, figsize=(10, 6), sharey=True, sharex=True)

# One panel per plant; each plant occupies 7 consecutive rows of the data.
for idx, ax in zip(range(0, 84, 7), axes.ravel()):
    ax.plot(CO2_conc, uptake[idx:idx+7], '.', lw=1, color='black')
    ax.plot(CO2_conc, sr_post_pred['uptake'].mean(axis=0)[idx:idx+7], "k", alpha=0.5);
    az.plot_hdi(CO2_conc, sr_post_pred['uptake'][:,idx:idx+7], color="C2", smooth=False, ax=ax)

fig.text(0.4, -0.05, "CO2 concentration", size=18)
fig.text(-0.03, 0.4, "CO2 uptake", size=18, rotation=90);
num_knots = 2
knot_list = np.linspace(CO2_conc[0], CO2_conc[-1], num_knots+2)[1:-1]

# Spline basis evaluated once on the shared 7-point concentration grid
# (rather than on every row, as for the pooled model).
Bi = dmatrix(
    "bs(conc, knots=knots, degree=3, include_intercept=True) - 1",
    {"conc": CO2_conc, "knots": knot_list},
)

Bi = torch.tensor(np.asarray(Bi)).float()

Mixed Effects Model - MCMC

def individual_response(design_matrix, groups, obs=None):
    """Per-group spline model: each group gets its own spline coefficients.

    Args:
        design_matrix: 2-D basis tensor of shape (N, P), evaluated on the grid
            shared by every group.
        groups: number of groups.
        obs: optional tensor of length ``groups * N`` laid out group by group
            (all N values of group 0, then group 1, ...); pass None to draw
            ``uptake`` from the model instead.
    """
    N, P = design_matrix.size()
    tau = pyro.sample('tau', dist.HalfCauchy(1.))
    sigma = pyro.sample('sigma', dist.HalfNormal(1.))
    # Independent Normal(0, tau) coefficients, one column per group.
    # to_event(2) declares both dimensions as part of a single draw.
    beta = pyro.sample('beta', dist.Normal(0., tau).expand([P, groups]).to_event(2))
    # (N, P) @ (P, groups) -> (N, groups): one fitted curve per group.
    ug = pyro.deterministic('ug', torch.matmul(design_matrix, beta))
    # Flatten group by group so the order matches `obs`; this is derived from
    # `ug` itself rather than from a module-level column index.
    ug = ug.T.reshape(-1)

    with pyro.plate('obs', ug.size(0)):
        pyro.sample('uptake', dist.Normal(ug, sigma), obs=obs)
# 500 posterior draws after 300 warm-up steps.
ir_mcmc = MCMC(NUTS(individual_response), 500, 300)
ir_mcmc.run(Bi, groups, uptake)
# NOTE(review): the chain holds 500 draws, so requesting 1000 must reuse
# draws; confirm this is intended.
ir_samples = ir_mcmc.get_samples(1000)
ir_post_pred = Predictive(individual_response, ir_samples)(Bi, groups, None)
fig, axes = plt.subplots(4, 3, figsize=(10, 6), sharey=True, sharex=True)

# One panel per plant; each plant occupies 7 consecutive entries of `uptake`.
for idx, ax in zip(range(0, 84, 7), axes.ravel()):
    ax.plot(CO2_conc, uptake[idx:idx+7], '.', lw=1, color='black')
    ax.plot(CO2_conc, ir_post_pred['uptake'].mean(axis=0)[idx:idx+7], "k", alpha=0.5);
    az.plot_hdi(CO2_conc, ir_post_pred['uptake'][:,idx:idx+7], color="C2", smooth=False, ax=ax)

fig.text(0.4, -0.05, "CO2 concentration", size=18)
fig.text(-0.03, 0.4, "CO2 uptake", size=18, rotation=90);
