Bayesian Modeling and Computation in Pyro - Chapter 5

bayesian-statistics
python
Author

Gabriel Stechschulte

Published

July 23, 2022

Code
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import arviz as az
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints, transforms
from pyro.infer import Predictive, TracePredictive, NUTS, MCMC
from pyro.infer.autoguide import AutoLaplaceApproximation
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from patsy import dmatrix
plt.style.use('ggplot')
plt.rcParams["figure.figsize"] = (9, 4)

5.5 - Fitting splines in Pyro

# Load the daily and hourly tables of the bike-sharing dataset.
day = pd.read_csv('./data/Bike-Sharing-Dataset/day.csv')
hour = pd.read_csv('./data/Bike-Sharing-Dataset/hour.csv')

# Rescale the hourly rental counts by dividing by the largest observed count.
peak_count = hour['cnt'].max()
hour['cnt_std'] = hour['cnt'] / peak_count

# Rescaled demand against hour of day.
sns.scatterplot(data=hour, x='hr', y='cnt_std', alpha=0.1, color='grey')
plt.xlabel('Hour of Day (0-23)')
plt.ylabel('Count')
plt.title('Actual Bike Demand');

# Candidate knot locations: num_knots + 2 evenly spaced points on [0, 23]
# with both end points dropped.
num_knots = 6
knot_list = torch.linspace(0, 23, num_knots + 2)[1:-1]

# Cubic B-spline basis (with intercept column) evaluated at every observed hour.
# NOTE(review): `knot_list` has already lost its end points above, so passing
# `knot_list[1:-1]` gives patsy only num_knots - 2 knots -- confirm this is
# intended before changing it, since the fitted results depend on it.
spline_data = {'cnt_std': hour.hr.values, 'knots': knot_list[1:-1]}
B = dmatrix(
    "bs(cnt_std, knots=knots, degree=3, include_intercept=True) - 1",
    spline_data,
)

# float32 tensors consumed by the Pyro models below.
B = torch.tensor(np.asarray(B), dtype=torch.float32)
cnt_bikes = torch.tensor(hour['cnt_std'].values, dtype=torch.float32)
hour_bikes = torch.tensor(hour['hr'].values, dtype=torch.float32).reshape(-1, 1)

Splines Model - MCMC

def splines(design_matrix, count_bikes=None):
    """Pyro model: Gaussian regression on a spline basis.

    Args:
        design_matrix: 2-D tensor; rows are observations, columns are basis
            functions.
        count_bikes: observed response passed as ``obs`` to the ``y`` site;
            when ``None`` the site is not conditioned on data.

    Sample sites: ``tau``, ``sigma``, ``beta`` (one per basis column), ``y``.
    Deterministic site: ``mu``.
    """
    n_obs, n_basis = design_matrix.shape

    # Scale of the coefficient prior and the observation noise.
    tau = pyro.sample('tau', dist.HalfCauchy(1.))
    sigma = pyro.sample('sigma', dist.HalfNormal(1.))

    # One coefficient per basis column, all sharing the scale tau.
    with pyro.plate('knot_list', n_basis):
        beta = pyro.sample('beta', dist.Normal(0., tau))

    # Linear predictor: weighted sum of the basis functions per observation.
    mu = pyro.deterministic('mu', beta @ design_matrix.T)

    with pyro.plate('output', n_obs):
        pyro.sample('y', dist.Normal(mu, sigma), obs=count_bikes)
# Draw the model graph, annotated with each site's distribution.
pyro.render_model(
    splines,
    model_args=(B, cnt_bikes),
    render_distributions=True,
)

# NUTS: keep 500 draws after 300 warm-up iterations.
kernel = NUTS(splines)
mcmc_splines = MCMC(kernel, num_samples=500, warmup_steps=300)
mcmc_splines.run(B, cnt_bikes)
Sample: 100%|██████████| 800/800 [00:12, 62.91it/s, step size=2.93e-01, acc. prob=0.887]
# 500 draws from the model with no observations supplied.
prior_sampler = Predictive(splines, num_samples=500)
prior_predictive = prior_sampler(B, None)

# Posterior draws, then predictions generated from those draws.
spline_samples = mcmc_splines.get_samples(num_samples=500)
posterior_sampler = Predictive(splines, posterior_samples=spline_samples)
splines_predictive = posterior_sampler(B, None)

# Bundle prior, posterior and posterior predictive into one ArviZ object.
az_splines_pred = az.from_pyro(
    posterior=mcmc_splines,
    prior=prior_predictive,
    posterior_predictive=splines_predictive,
)
# Mean of the `mu` site across draws, plotted against hour of day.
hours_flat = hour_bikes.flatten()
mean_mu = splines_predictive['mu'].mean(axis=0).T.flatten()
sns.lineplot(x=hours_flat, y=mean_mu, color='black')

# Column 5 of the basis matrix scaled by the mean `beta` draws (dashed).
weighted_basis = B * spline_samples['beta'].mean(axis=0)
sns.lineplot(x=hours_flat, y=weighted_basis[:, 5], linestyle='--');

# Mean and standard deviation of the `y` site across draws.
y_draws = splines_predictive['y']
cnt_mu = y_draws.mean(axis=0).T.flatten()
cnt_std = y_draws.std(axis=0).T.flatten()

# One row per observation with a mean +/- one-std band, ordered by hour so
# the line and band plots are drawn left to right.
df = pd.DataFrame({
    'hr': hour['hr'].values,
    'cnt_scaled': hour['cnt_std'].values,
    'cnt_mu': cnt_mu,
    'cnt_std': cnt_std,
    'cnt_high': cnt_mu + cnt_std,
    'cnt_low': cnt_mu - cnt_std
}).sort_values(by=['hr'])

Figure 5.9

# Mean of the predictive draws per hour.
sns.lineplot(data=df, x='hr', y='cnt_mu', color='blue')

# Observed (rescaled) counts.
sns.scatterplot(
    data=hour, x='hr', y='cnt_std', color='grey', alpha=0.3
)

# Band of mean +/- one standard deviation of the predictive draws.
plt.fill_between(
    df['hr'], df['cnt_high'], df['cnt_low'], color='grey', alpha=0.3
)

# Mark the locations stored in `knot_list` along y = 0.
plt.scatter(knot_list, np.zeros_like(knot_list), color='black')
plt.title('Cubic Splines using 6 Knots');

5.6 - Choosing knots and priors for splines

# One spline design matrix per candidate knot count, in the order of `num_knots`.
num_knots = [3, 6, 9, 12, 18]
Bs = []
for nk in num_knots:
    # nk + 2 evenly spaced points on [0, 24] with the end points removed.
    # NOTE(review): the formula data slices `knot_list[1:-1]` again, so patsy
    # receives nk - 2 knots -- confirm this is intended.
    knot_list = torch.linspace(0, 24, nk + 2)[1:-1]
    formula_data = {'cnt': hour.hr.values, 'knots': knot_list[1:-1]}
    B = dmatrix(
        'bs(cnt, knots=knots, degree=3, include_intercept=True) - 1',
        formula_data,
    )
    B = torch.tensor(np.asarray(B), dtype=torch.float32)
    Bs.append(B)
# Fit the spline model once per design matrix and keep the ArviZ results in
# the same order as `Bs`.
inf_data = []
for B in Bs:
    # 500 retained draws after 300 warm-up iterations.
    mcmc_obj = MCMC(NUTS(splines), num_samples=500, warmup_steps=300)
    mcmc_obj.run(B, cnt_bikes)

    # Predictions generated from the posterior draws.
    post_samples = mcmc_obj.get_samples(num_samples=500)
    predictive = Predictive(splines, posterior_samples=post_samples)
    post_pred = predictive(B, None)

    az_obj = az.from_pyro(
        posterior=mcmc_obj,
        posterior_predictive=post_pred,
    )
    inf_data.append(az_obj)
Sample: 100%|██████████| 800/800 [00:16, 48.58it/s, step size=2.62e-01, acc. prob=0.924]
Sample: 100%|██████████| 800/800 [00:13, 58.92it/s, step size=2.90e-01, acc. prob=0.888]
Sample: 100%|██████████| 800/800 [00:14, 55.82it/s, step size=2.67e-01, acc. prob=0.910]
Sample: 100%|██████████| 800/800 [00:13, 58.91it/s, step size=3.04e-01, acc. prob=0.886]
Sample: 100%|██████████| 800/800 [00:17, 45.97it/s, step size=2.59e-01, acc. prob=0.890]
# Compare the fitted models with LOO on the `y` site; keys are "m_<knots>k".
# NOTE(review): something is not right here -- the table printed below ranks
# the fewest-knot model first and elpd_loo drops steadily as knots are added.
# The cause has not been identified yet.
dict_cmp = {f"m_{k}k": v for k, v in zip(num_knots, inf_data)}
cmp = az.compare(dict_cmp, ic='loo', var_name='y')
cmp
['m_18k', 'm_12k', 'm_9k', 'm_6k', 'm_3k']
rank elpd_loo p_loo elpd_diff weight se dse warning scale
m_3k 0 10575.715589 20.874493 0.000000 0.850828 129.781140 0.000000 False log
m_6k 1 10423.693963 14.518533 152.021626 0.000000 131.808422 19.387344 False log
m_9k 2 10094.535427 12.458113 481.180162 0.000000 133.635244 36.071652 False log
m_12k 3 9580.695289 8.562018 995.020300 0.000000 136.447458 53.955052 False log
m_18k 4 8600.222467 6.347949 1975.493122 0.149172 142.746908 81.838699 False log
# Per-model line styling, positionally aligned with `inf_data`.
colors = ['black', 'blue', 'grey', 'grey', 'black']
linestyle = ["-", "-", "--", "--", "-"]
linewidth = [1.5, 3, 1.5, 1.5, 3]

# Observed data: identical for every model, so draw it once rather than once
# per loop iteration (which stacked five semi-transparent copies).
sns.scatterplot(
    x=hour['hr'], y=hour['cnt_std'].values, color='lightgrey', alpha=0.5, edgecolor='grey'
)

# `inf_data` was filled in the order of `num_knots`, so the labels must be
# taken from `num_knots` as-is. Pairing with a reverse-sorted copy attached
# the wrong knot count to every curve except the middle one.
for ob, col, knots, ls, lw in zip(
    inf_data, colors, num_knots, linestyle, linewidth
):
    # Mean over draws of the first chain of the posterior predictive `y`.
    sns.lineplot(
        x=hour['hr'], y=ob['posterior_predictive']['y'][0].mean(axis=0), color=col,
        label=f'knots={knots}', linestyle=ls, linewidth=lw
    )

plt.title('Model fit with different number of knots');

5.6.1 Regularizing priors for splines

class GaussianRandomWalk(dist.TorchDistribution):
    """Gaussian random walk with a fixed number of steps.

    A sample is the cumulative sum of ``num_steps`` independent standard
    normal increments, multiplied by ``scale``. The step axis is the event
    dimension; ``scale.shape`` is the batch shape.

    Args:
        scale: positive tensor giving the standard deviation of each increment.
        num_steps: length of the walk (size of the event dimension).
    """
    has_rsample = True
    arg_constraints = {'scale': constraints.positive}
    # NOTE(review): the event is 1-D; an event-aware constraint may be more
    # appropriate here -- left unchanged to keep inference behaviour as-is.
    support = constraints.real

    def __init__(self, scale, num_steps=1):
        self.scale = scale
        batch_shape, event_shape = scale.shape, torch.Size([num_steps])
        super().__init__(batch_shape, event_shape)

    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterized walks of shape sample + batch + event."""
        shape = sample_shape + self.batch_shape + self.event_shape
        walks = self.scale.new_empty(shape).normal_()
        # Accumulate unit-variance increments, then scale per batch element.
        return walks.cumsum(-1) * self.scale.unsqueeze(-1)

    def log_prob(self, x):
        """Log density of the walk ``x`` (last axis indexes the steps)."""
        # First position: Normal(0, scale).
        init_prob = dist.Normal(self.scale.new_tensor(0.), self.scale).log_prob(x[..., 0])
        # Each later position is Normal(previous position, scale). The scale
        # gets a trailing axis so a batched scale lines up with the batch
        # dimensions of x rather than its step axis, mirroring rsample.
        step_scale = self.scale.unsqueeze(-1)
        step_probs = dist.Normal(x[..., :-1], step_scale).log_prob(x[..., 1:])
        return init_prob + step_probs.sum(-1)
def splines_grw(design_matrix, count_bikes=None):
    """Pyro spline regression with a Gaussian-random-walk prior on the weights.

    Args:
        design_matrix: 2-D tensor; rows are observations, columns are basis
            functions.
        count_bikes: observed response passed as ``obs`` to the ``y`` site;
            when ``None`` the site is not conditioned on data.

    Sample sites: ``tau``, ``sigma``, ``beta``, ``y``. Deterministic: ``mu``.
    """
    N, P = design_matrix.shape

    tau = pyro.sample('tau', dist.HalfCauchy(1.))
    sigma = pyro.sample('sigma', dist.HalfNormal(1.))

    # The walk length follows the number of basis columns instead of a
    # hard-coded 14, so the matmul below is valid for any design matrix
    # (the result is unchanged for a 14-column matrix).
    # NOTE(review): sampling inside a plate of size P yields P independent
    # walks, so `beta`, `mu` and `y` all carry an extra leading axis of size
    # P. A single walk without the plate may be what was intended, but
    # downstream code averages over that axis, so it is kept.
    with pyro.plate('knot_list', P):
        beta = pyro.sample('beta', GaussianRandomWalk(scale=tau, num_steps=P))

    mu = pyro.deterministic('mu', torch.matmul(beta, design_matrix.T))

    with pyro.plate('output', N):
        pyro.sample('y', dist.Normal(mu, sigma), obs=count_bikes)
# Knot locations for the random-walk model: 14 evenly spaced points on
# [0, 23] with the end points dropped.
num_knots = 12
knot_list = torch.linspace(0, 23, num_knots + 2)[1:-1]

# NOTE(review): `knot_list[1:-1]` trims the list a second time, so patsy
# receives num_knots - 2 knots -- confirm before changing; the summary
# printed below shows 14 coefficient columns for this setup.
grw_spline_data = {'cnt_std': hour.hr.values, 'knots': knot_list[1:-1]}
B = dmatrix(
    "bs(cnt_std, knots=knots, degree=3, include_intercept=True) - 1",
    grw_spline_data,
)

# float32 tensors for the model inputs.
B = torch.tensor(np.asarray(B), dtype=torch.float32)
cnt_bikes = torch.tensor(hour['cnt_std'].values, dtype=torch.float32)
hour_bikes = torch.tensor(hour['hr'].values, dtype=torch.float32).reshape(-1, 1)

# NUTS: keep 500 draws after 300 warm-up iterations.
splines_grw_mcmc = MCMC(NUTS(splines_grw), num_samples=500, warmup_steps=300)
splines_grw_mcmc.run(B, cnt_bikes)
Sample: 100%|██████████| 800/800 [01:48,  7.40it/s, step size=1.51e-01, acc. prob=0.878]
splines_grw_mcmc.summary()

                 mean       std    median      5.0%     95.0%     n_eff     r_hat
  beta[0,0]      0.06      0.00      0.06      0.05      0.06    793.18      1.00
  beta[0,1]     -0.00      0.01     -0.00     -0.02      0.02    543.60      1.00
  beta[0,2]      0.07      0.01      0.07      0.05      0.08    539.25      1.00
  beta[0,3]     -0.07      0.01     -0.07     -0.08     -0.05    593.26      1.00
  beta[0,4]      0.32      0.01      0.32      0.31      0.34    598.96      1.00
  beta[0,5]      0.30      0.01      0.30      0.29      0.32    647.04      1.00
  beta[0,6]      0.10      0.01      0.10      0.09      0.12    580.51      1.00
  beta[0,7]      0.34      0.01      0.34      0.33      0.36    588.65      1.00
  beta[0,8]      0.19      0.01      0.19      0.18      0.20    556.40      1.00
  beta[0,9]      0.31      0.01      0.31      0.30      0.32    542.28      1.00
 beta[0,10]      0.56      0.01      0.56      0.55      0.58    614.87      1.00
 beta[0,11]      0.11      0.01      0.11      0.09      0.13    651.78      1.00
 beta[0,12]      0.19      0.01      0.19      0.17      0.20    619.13      1.00
 beta[0,13]      0.09      0.01      0.09      0.08      0.10    684.85      1.00
  beta[1,0]      0.06      0.01      0.06      0.05      0.07    684.00      1.00
  beta[1,1]     -0.00      0.01     -0.00     -0.02      0.01    546.69      1.00
  beta[1,2]      0.07      0.01      0.07      0.05      0.09    424.07      1.01
  beta[1,3]     -0.07      0.01     -0.07     -0.08     -0.05    269.08      1.02
  beta[1,4]      0.32      0.01      0.32      0.31      0.33    414.38      1.02
  beta[1,5]      0.30      0.01      0.30      0.29      0.32    618.06      1.01
  beta[1,6]      0.10      0.01      0.10      0.09      0.12    563.03      1.00
  beta[1,7]      0.34      0.01      0.34      0.33      0.36    422.95      1.00
  beta[1,8]      0.19      0.01      0.19      0.18      0.20    553.61      1.00
  beta[1,9]      0.31      0.01      0.31      0.29      0.32    468.49      1.00
 beta[1,10]      0.56      0.01      0.56      0.55      0.58    446.50      1.00
 beta[1,11]      0.11      0.01      0.11      0.10      0.13    373.58      1.00
 beta[1,12]      0.19      0.01      0.19      0.17      0.20    500.06      1.00
 beta[1,13]      0.09      0.01      0.09      0.08      0.10    610.37      1.00
  beta[2,0]      0.06      0.01      0.06      0.05      0.07    564.38      1.00
  beta[2,1]     -0.00      0.01     -0.00     -0.02      0.01    599.84      1.00
  beta[2,2]      0.07      0.01      0.07      0.05      0.09    452.10      1.00
  beta[2,3]     -0.07      0.01     -0.07     -0.08     -0.05    550.86      1.00
  beta[2,4]      0.32      0.01      0.32      0.31      0.34    599.10      1.00
  beta[2,5]      0.30      0.01      0.30      0.29      0.32    688.41      1.00
  beta[2,6]      0.10      0.01      0.10      0.09      0.11    497.99      1.00
  beta[2,7]      0.34      0.01      0.34      0.33      0.36    439.54      1.00
  beta[2,8]      0.19      0.01      0.19      0.17      0.20    695.93      1.00
  beta[2,9]      0.31      0.01      0.31      0.29      0.32    728.54      1.00
 beta[2,10]      0.56      0.01      0.56      0.55      0.58    574.70      1.00
 beta[2,11]      0.11      0.01      0.11      0.09      0.13    543.40      1.00
 beta[2,12]      0.19      0.01      0.19      0.17      0.20    512.70      1.00
 beta[2,13]      0.09      0.01      0.09      0.08      0.10    635.21      1.00
  beta[3,0]      0.06      0.01      0.06      0.05      0.07    598.90      1.00
  beta[3,1]     -0.00      0.01     -0.00     -0.02      0.01    480.85      1.00
  beta[3,2]      0.07      0.01      0.07      0.05      0.08    540.00      1.00
  beta[3,3]     -0.07      0.01     -0.07     -0.08     -0.05    461.03      1.00
  beta[3,4]      0.32      0.01      0.32      0.31      0.33    489.37      1.00
  beta[3,5]      0.30      0.01      0.30      0.29      0.32    484.36      1.00
  beta[3,6]      0.10      0.01      0.10      0.09      0.12    591.71      1.00
  beta[3,7]      0.34      0.01      0.34      0.33      0.36    468.37      1.00
  beta[3,8]      0.19      0.01      0.19      0.18      0.20    325.26      1.00
  beta[3,9]      0.31      0.01      0.31      0.30      0.32    618.83      1.00
 beta[3,10]      0.56      0.01      0.56      0.55      0.58    502.47      1.00
 beta[3,11]      0.11      0.01      0.11      0.09      0.12    542.12      1.00
 beta[3,12]      0.19      0.01      0.19      0.17      0.20    576.58      1.00
 beta[3,13]      0.09      0.01      0.09      0.08      0.09    709.30      1.00
  beta[4,0]      0.06      0.00      0.06      0.05      0.07    785.06      1.00
  beta[4,1]     -0.00      0.01     -0.00     -0.02      0.01    604.06      1.00
  beta[4,2]      0.07      0.01      0.07      0.05      0.09    493.62      1.00
  beta[4,3]     -0.07      0.01     -0.07     -0.08     -0.05    511.71      1.00
  beta[4,4]      0.32      0.01      0.32      0.31      0.34    497.03      1.00
  beta[4,5]      0.30      0.01      0.30      0.29      0.31    657.27      1.00
  beta[4,6]      0.10      0.01      0.10      0.09      0.12    577.98      1.00
  beta[4,7]      0.34      0.01      0.34      0.33      0.36    514.37      1.00
  beta[4,8]      0.19      0.01      0.19      0.18      0.20    470.99      1.00
  beta[4,9]      0.31      0.01      0.31      0.29      0.32    669.33      1.00
 beta[4,10]      0.56      0.01      0.56      0.55      0.58    540.58      1.00
 beta[4,11]      0.11      0.01      0.11      0.09      0.12    502.99      1.00
 beta[4,12]      0.19      0.01      0.19      0.17      0.20    417.16      1.00
 beta[4,13]      0.09      0.00      0.09      0.08      0.09    592.79      1.00
  beta[5,0]      0.06      0.01      0.06      0.05      0.06    841.34      1.00
  beta[5,1]     -0.00      0.01     -0.00     -0.02      0.01    598.37      1.00
  beta[5,2]      0.07      0.01      0.07      0.05      0.08    393.97      1.00
  beta[5,3]     -0.07      0.01     -0.07     -0.08     -0.05    470.53      1.00
  beta[5,4]      0.32      0.01      0.32      0.31      0.33    519.00      1.00
  beta[5,5]      0.30      0.01      0.30      0.29      0.32    454.97      1.00
  beta[5,6]      0.10      0.01      0.10      0.09      0.11    467.47      1.00
  beta[5,7]      0.34      0.01      0.34      0.33      0.36    453.50      1.00
  beta[5,8]      0.19      0.01      0.19      0.17      0.20    476.61      1.00
  beta[5,9]      0.31      0.01      0.31      0.29      0.32    468.98      1.00
 beta[5,10]      0.56      0.01      0.56      0.55      0.58    412.85      1.00
 beta[5,11]      0.11      0.01      0.11      0.09      0.13    417.32      1.00
 beta[5,12]      0.19      0.01      0.19      0.17      0.20    487.35      1.01
 beta[5,13]      0.09      0.00      0.09      0.08      0.10    855.13      1.00
  beta[6,0]      0.06      0.00      0.06      0.05      0.06    469.53      1.00
  beta[6,1]     -0.00      0.01     -0.00     -0.02      0.01    490.03      1.00
  beta[6,2]      0.07      0.01      0.07      0.05      0.08    412.12      1.00
  beta[6,3]     -0.07      0.01     -0.07     -0.08     -0.05    547.10      1.00
  beta[6,4]      0.32      0.01      0.32      0.31      0.34    513.47      1.00
  beta[6,5]      0.30      0.01      0.30      0.29      0.32    527.11      1.00
  beta[6,6]      0.10      0.01      0.10      0.09      0.11    568.13      1.00
  beta[6,7]      0.34      0.01      0.34      0.33      0.36    621.95      1.00
  beta[6,8]      0.19      0.01      0.19      0.18      0.20    611.84      1.00
  beta[6,9]      0.31      0.01      0.31      0.30      0.32    591.34      1.00
 beta[6,10]      0.56      0.01      0.56      0.55      0.58    485.72      1.00
 beta[6,11]      0.11      0.01      0.11      0.10      0.13    454.76      1.00
 beta[6,12]      0.19      0.01      0.19      0.17      0.20    418.77      1.00
 beta[6,13]      0.09      0.00      0.09      0.08      0.10    635.51      1.00
  beta[7,0]      0.06      0.00      0.06      0.05      0.07    633.33      1.00
  beta[7,1]     -0.00      0.01     -0.00     -0.02      0.01    574.07      1.00
  beta[7,2]      0.07      0.01      0.07      0.05      0.08    519.25      1.00
  beta[7,3]     -0.07      0.01     -0.07     -0.08     -0.05    619.73      1.00
  beta[7,4]      0.32      0.01      0.32      0.31      0.34    623.32      1.00
  beta[7,5]      0.30      0.01      0.30      0.29      0.32    487.51      1.00
  beta[7,6]      0.10      0.01      0.10      0.09      0.12    521.75      1.00
  beta[7,7]      0.34      0.01      0.34      0.33      0.36    528.14      1.00
  beta[7,8]      0.19      0.01      0.19      0.18      0.20    536.29      1.00
  beta[7,9]      0.31      0.01      0.31      0.30      0.32    474.66      1.01
 beta[7,10]      0.56      0.01      0.56      0.55      0.58    444.83      1.01
 beta[7,11]      0.11      0.01      0.11      0.10      0.13    578.87      1.00
 beta[7,12]      0.18      0.01      0.19      0.17      0.20    636.52      1.00
 beta[7,13]      0.09      0.00      0.09      0.08      0.09    726.01      1.00
  beta[8,0]      0.06      0.00      0.06      0.05      0.06    759.45      1.00
  beta[8,1]     -0.00      0.01     -0.00     -0.02      0.02    451.96      1.00
  beta[8,2]      0.07      0.01      0.07      0.05      0.09    556.69      1.00
  beta[8,3]     -0.07      0.01     -0.07     -0.08     -0.05    587.00      1.00
  beta[8,4]      0.32      0.01      0.32      0.31      0.34    432.25      1.01
  beta[8,5]      0.30      0.01      0.30      0.29      0.32    497.88      1.00
  beta[8,6]      0.10      0.01      0.10      0.09      0.12    493.79      1.00
  beta[8,7]      0.34      0.01      0.34      0.33      0.36    473.27      1.00
  beta[8,8]      0.19      0.01      0.19      0.18      0.20    511.87      1.00
  beta[8,9]      0.31      0.01      0.31      0.30      0.32    560.39      1.00
 beta[8,10]      0.56      0.01      0.56      0.55      0.57    643.56      1.00
 beta[8,11]      0.11      0.01      0.11      0.09      0.13    615.96      1.00
 beta[8,12]      0.19      0.01      0.19      0.17      0.20    789.36      1.00
 beta[8,13]      0.09      0.00      0.09      0.08      0.10    722.99      1.00
  beta[9,0]      0.06      0.01      0.06      0.05      0.07    705.95      1.00
  beta[9,1]     -0.00      0.01     -0.00     -0.02      0.02    539.48      1.01
  beta[9,2]      0.07      0.01      0.07      0.05      0.09    729.97      1.01
  beta[9,3]     -0.07      0.01     -0.07     -0.08     -0.05    606.84      1.01
  beta[9,4]      0.32      0.01      0.32      0.31      0.33    533.78      1.00
  beta[9,5]      0.30      0.01      0.30      0.29      0.32    770.46      1.00
  beta[9,6]      0.10      0.01      0.10      0.09      0.11    672.34      1.00
  beta[9,7]      0.34      0.01      0.34      0.33      0.36    590.69      1.00
  beta[9,8]      0.19      0.01      0.19      0.17      0.20    701.83      1.00
  beta[9,9]      0.31      0.01      0.31      0.30      0.32    491.28      1.00
 beta[9,10]      0.56      0.01      0.56      0.55      0.58    501.94      1.00
 beta[9,11]      0.11      0.01      0.11      0.10      0.13    458.62      1.00
 beta[9,12]      0.19      0.01      0.19      0.17      0.20    491.62      1.01
 beta[9,13]      0.09      0.00      0.09      0.08      0.09    694.76      1.00
 beta[10,0]      0.06      0.01      0.06      0.05      0.07    661.86      1.00
 beta[10,1]     -0.00      0.01     -0.00     -0.02      0.01    560.89      1.00
 beta[10,2]      0.07      0.01      0.07      0.05      0.09    487.57      1.00
 beta[10,3]     -0.07      0.01     -0.07     -0.08     -0.05    538.75      1.00
 beta[10,4]      0.32      0.01      0.32      0.31      0.34    582.82      1.00
 beta[10,5]      0.30      0.01      0.30      0.29      0.32    441.16      1.00
 beta[10,6]      0.10      0.01      0.10      0.09      0.12    413.37      1.01
 beta[10,7]      0.34      0.01      0.34      0.33      0.36    529.81      1.02
 beta[10,8]      0.19      0.01      0.19      0.18      0.20    615.98      1.01
 beta[10,9]      0.31      0.01      0.31      0.29      0.32    518.03      1.02
beta[10,10]      0.56      0.01      0.56      0.55      0.57    549.82      1.00
beta[10,11]      0.11      0.01      0.11      0.10      0.13    526.15      1.00
beta[10,12]      0.19      0.01      0.19      0.17      0.20    506.01      1.00
beta[10,13]      0.09      0.00      0.09      0.08      0.09    626.06      1.00
 beta[11,0]      0.06      0.01      0.06      0.05      0.07    656.89      1.00
 beta[11,1]     -0.00      0.01     -0.00     -0.02      0.01    501.48      1.01
 beta[11,2]      0.07      0.01      0.07      0.05      0.08    434.22      1.01
 beta[11,3]     -0.07      0.01     -0.07     -0.08     -0.06    455.18      1.01
 beta[11,4]      0.32      0.01      0.32      0.31      0.34    609.90      1.00
 beta[11,5]      0.30      0.01      0.30      0.29      0.31    742.15      1.00
 beta[11,6]      0.10      0.01      0.10      0.09      0.12    811.31      1.00
 beta[11,7]      0.34      0.01      0.34      0.33      0.36    572.83      1.00
 beta[11,8]      0.19      0.01      0.19      0.17      0.20    642.85      1.00
 beta[11,9]      0.31      0.01      0.31      0.30      0.32    541.95      1.00
beta[11,10]      0.56      0.01      0.56      0.55      0.58    488.09      1.00
beta[11,11]      0.11      0.01      0.11      0.10      0.13    400.39      1.00
beta[11,12]      0.19      0.01      0.19      0.17      0.20    502.89      1.00
beta[11,13]      0.09      0.00      0.09      0.08      0.09    770.16      1.00
 beta[12,0]      0.06      0.00      0.06      0.05      0.06    835.10      1.00
 beta[12,1]     -0.00      0.01     -0.00     -0.02      0.01    525.41      1.00
 beta[12,2]      0.07      0.01      0.07      0.05      0.09    461.99      1.00
 beta[12,3]     -0.07      0.01     -0.07     -0.08     -0.05    402.14      1.00
 beta[12,4]      0.32      0.01      0.32      0.31      0.33    536.10      1.00
 beta[12,5]      0.30      0.01      0.30      0.29      0.32    508.00      1.00
 beta[12,6]      0.10      0.01      0.10      0.09      0.12    432.68      1.00
 beta[12,7]      0.34      0.01      0.34      0.33      0.36    456.73      1.00
 beta[12,8]      0.19      0.01      0.19      0.18      0.20    454.93      1.00
 beta[12,9]      0.31      0.01      0.31      0.29      0.32    429.57      1.00
beta[12,10]      0.56      0.01      0.56      0.55      0.58    428.65      1.00
beta[12,11]      0.11      0.01      0.11      0.09      0.13    467.15      1.00
beta[12,12]      0.19      0.01      0.19      0.17      0.20    514.43      1.00
beta[12,13]      0.09      0.01      0.09      0.08      0.09    648.53      1.00
 beta[13,0]      0.06      0.00      0.06      0.05      0.06    693.73      1.00
 beta[13,1]     -0.00      0.01     -0.00     -0.02      0.01    571.39      1.00
 beta[13,2]      0.07      0.01      0.07      0.05      0.08    705.45      1.00
 beta[13,3]     -0.07      0.01     -0.07     -0.08     -0.05    672.76      1.00
 beta[13,4]      0.32      0.01      0.32      0.31      0.34    675.63      1.00
 beta[13,5]      0.30      0.01      0.30      0.29      0.32    508.98      1.00
 beta[13,6]      0.10      0.01      0.10      0.09      0.12    562.72      1.00
 beta[13,7]      0.34      0.01      0.34      0.33      0.36    554.37      1.00
 beta[13,8]      0.19      0.01      0.19      0.18      0.20    516.77      1.00
 beta[13,9]      0.31      0.01      0.31      0.29      0.32    493.20      1.00
beta[13,10]      0.56      0.01      0.56      0.55      0.58    494.85      1.00
beta[13,11]      0.11      0.01      0.11      0.09      0.13    510.48      1.00
beta[13,12]      0.19      0.01      0.19      0.17      0.20    676.12      1.00
beta[13,13]      0.09      0.00      0.09      0.08      0.10    843.17      1.00
      sigma      0.13      0.00      0.13      0.13      0.13    562.53      1.00
        tau      0.21      0.01      0.21      0.19      0.23    968.57      1.00

Number of divergences: 0
# Request 1000 posterior draws and generate predictions from them.
splines_grw_samples = splines_grw_mcmc.get_samples(num_samples=1000)
grw_sampler = Predictive(splines_grw, posterior_samples=splines_grw_samples)
splines_grw_post_pred = grw_sampler(B, None)

# Two means: first over axis 1 (the random-walk axis), then over axis 0
# (the 1000 posterior samples).
grw_mean_fn = splines_grw_post_pred['y'].mean(axis=1).mean(axis=0)
sns.lineplot(
    x=hour['hr'], y=grw_mean_fn,
    color='blue', lw=2, label='splines_grw mean function'
)

# Observed (rescaled) counts for reference.
sns.scatterplot(
    x=hour['hr'], y=hour['cnt_std'].values,
    color='lightgrey', alpha=0.5, edgecolor='grey'
)
plt.title('Gaussian Random Walk Prior');

Modeling \(\text{CO}_2\) Uptake with Splines

# CO2 uptake measurements.
plants_CO2 = pd.read_csv("./data/CO2_uptake.csv")
plant_names = plants_CO2['Plant'].unique()

# All concentration values, and the first 7 of them.
# NOTE(review): presumably 7 concentration levels per plant -- confirm
# against the data file.
CO2_concs = plants_CO2['conc'].values
CO2_conc = CO2_concs[:7]

# Group bookkeeping: 12 groups (presumably one per plant -- confirm).
index = range(12)
groups = len(index)

# Two knots strictly between the first and last value of `CO2_conc`.
num_knots = 2
knot_list = np.linspace(CO2_conc[0], CO2_conc[-1], num_knots + 2)[1:-1]

# Cubic B-spline basis evaluated at every row of the table.
co2_spline_data = {"conc": CO2_concs, "knots": knot_list}
Bg = dmatrix(
    "bs(conc, knots=knots, degree=3, include_intercept=True) - 1",
    co2_spline_data,
)

# float32 tensors for the models.
Bg = torch.tensor(np.asarray(Bg), dtype=torch.float32)
uptake = torch.tensor(plants_CO2['uptake'].values, dtype=torch.float32)

Pooled Model - MCMC

def single_response(design_matrix, obs=None):
    """Pyro model: one spline response curve shared by all observations.

    Args:
        design_matrix: 2-D tensor; rows are observations, columns are basis
            functions.
        obs: observed response passed to the ``uptake`` site; when ``None``
            the site is not conditioned on data.

    Sample sites: ``tau``, ``sigma``, ``beta`` (one per basis column),
    ``uptake``. Deterministic site: ``ug``.
    """
    n_obs, n_basis = design_matrix.shape

    # Scale of the coefficient prior and the observation noise.
    tau = pyro.sample('tau', dist.HalfCauchy(1.))
    sigma = pyro.sample('sigma', dist.HalfNormal(1.))

    # One coefficient per basis column.
    with pyro.plate('coef', n_basis):
        beta = pyro.sample('beta', dist.Normal(0., tau))

    # Linear predictor for every observation.
    ug = pyro.deterministic('ug', beta @ design_matrix.T)

    with pyro.plate('obs', n_obs):
        pyro.sample('uptake', dist.Normal(ug, sigma), obs=obs)
# Draw the pooled model's graph, annotated with each site's distribution.
pyro.render_model(
    single_response,
    model_args=(Bg, uptake),
    render_distributions=True,
)

# NUTS: keep 500 draws after 300 warm-up iterations.
sr_mcmc = MCMC(NUTS(single_response), num_samples=500, warmup_steps=300)
sr_mcmc.run(Bg, uptake)
Sample: 100%|██████████| 800/800 [00:38, 20.99it/s, step size=1.98e-01, acc. prob=0.934]
sr_mcmc.summary()

                mean       std    median      5.0%     95.0%     n_eff     r_hat
   beta[0]     12.21      1.82     12.27      9.20     15.06    261.26      1.00
   beta[1]     30.06      3.52     30.11     24.50     35.95    251.13      1.00
   beta[2]     30.06      6.05     29.81     20.25     39.69    193.86      1.00
   beta[3]     33.86      9.16     33.72     20.00     49.44    171.13      1.00
   beta[4]     25.27     22.92     24.96    -14.30     60.07    186.48      1.01
   beta[5]     33.41      2.03     33.29     30.06     36.65    395.37      1.00
     sigma      6.77      0.38      6.75      6.24      7.48    596.67      1.00
       tau     31.24     10.56     28.94     17.24     47.38    226.15      1.00

Number of divergences: 0
# Request 1000 posterior draws and generate predictions from them.
sr_samples = sr_mcmc.get_samples(num_samples=1000)
sr_sampler = Predictive(single_response, posterior_samples=sr_samples)
sr_post_pred = sr_sampler(Bg, None)

# 4 x 3 grid of panels; each panel shows 7 consecutive rows of the data.
fig, axes = plt.subplots(4, 3, figsize=(10, 6), sharey=True, sharex=True)
sr_uptake_draws = sr_post_pred['uptake']
sr_uptake_mean = sr_uptake_draws.mean(axis=0)

for count, ax in enumerate(axes.ravel()):
    idx = 7 * count
    rows = slice(idx, idx + 7)
    # Observed points, mean of the predictive draws, and their HDI band.
    ax.plot(CO2_conc, uptake[rows], '.', lw=1, color='black')
    ax.plot(CO2_conc, sr_uptake_mean[rows], "k", alpha=0.5);
    az.plot_hdi(CO2_conc, sr_uptake_draws[:, rows], color="C2", smooth=False, ax=ax)
    ax.set_title(plant_names[count])

plt.tight_layout()
# Shared axis labels placed on the figure margins.
fig.text(0.4, -0.05, "CO2 concentration", size=18)
fig.text(-0.03, 0.4, "CO2 uptake", size=18, rotation=90);
/Users/gabestechschulte/miniforge3/envs/probs/lib/python3.10/site-packages/arviz/plots/hdiplot.py:156: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False, **hdi_kwargs)

# Same knot placement as before, but the basis is evaluated only at the
# values held in `CO2_conc`.
num_knots = 2
knot_list = np.linspace(CO2_conc[0], CO2_conc[-1], num_knots + 2)[1:-1]

individual_spline_data = {"conc": CO2_conc, "knots": knot_list}
Bi = dmatrix(
    "bs(conc, knots=knots, degree=3, include_intercept=True) - 1",
    individual_spline_data,
)

# float32 tensor for the model.
Bi = torch.tensor(np.asarray(Bi), dtype=torch.float32)

Mixed Effects Model - MCMC

def individual_response(design_matrix, groups, obs=None):
    """Pyro model: a separate spline response curve for each group.

    Args:
        design_matrix: 2-D tensor of shape (N, P); rows are evaluation
            points, columns are basis functions.
        groups: number of groups; one column of coefficients is sampled per
            group.
        obs: observed response passed to the ``uptake`` site, of length
            ``groups * N``; when ``None`` the site is not conditioned on data.
            NOTE(review): assumed to be ordered group by group, N values
            each, to match the flattening below -- confirm against the data.

    Sample sites: ``tau``, ``sigma``, ``beta`` (P x groups), ``uptake``.
    Deterministic site: ``ug`` (N x groups).
    """
    N, P = design_matrix.size()

    tau = pyro.sample('tau', dist.HalfCauchy(1.))
    sigma = pyro.sample('sigma', dist.HalfNormal(1.))
    beta = pyro.sample('beta', dist.Normal(0., tau).expand([P, groups]))
    ug = pyro.deterministic('ug', torch.matmul(design_matrix, beta))

    # Flatten (N, groups) group by group. Every column is used, so no index is
    # needed; the previous version indexed with the module-level `index`
    # global, which silently tied the model to a fixed number of groups
    # regardless of the `groups` argument.
    ug = ug.T.ravel()

    with pyro.plate('obs', ug.size(0)):
        pyro.sample('uptake', dist.Normal(ug, sigma), obs=obs)
# NUTS: keep 500 draws after 300 warm-up iterations.
ir_kernel = NUTS(individual_response)
ir_mcmc = MCMC(ir_kernel, num_samples=500, warmup_steps=300)
ir_mcmc.run(Bi, groups, uptake)
Sample: 100%|██████████| 800/800 [01:03, 12.67it/s, step size=1.36e-01, acc. prob=0.895]
ir_mcmc.summary()

                mean       std    median      5.0%     95.0%     n_eff     r_hat
 beta[0,0]     15.89      2.02     15.80     12.85     19.17    749.09      1.00
 beta[0,1]     13.24      1.97     13.24     10.27     16.83    818.95      1.00
 beta[0,2]     16.11      2.09     16.17     13.20     20.12    861.96      1.00
 beta[0,3]     14.01      2.04     13.96     10.65     17.05   1070.16      1.00
 beta[0,4]      9.39      2.04      9.29      6.12     12.53    646.55      1.00
 beta[0,5]     13.72      2.06     13.76     10.31     17.12    797.97      1.00
 beta[0,6]     10.36      1.84     10.24      7.20     13.16    733.61      1.00
 beta[0,7]     11.61      1.96     11.58      8.36     14.59    499.54      1.00
 beta[0,8]     11.02      1.85     10.98      7.65     13.60    414.56      1.00
 beta[0,9]     10.31      1.97     10.33      6.76     13.22    907.83      1.01
beta[0,10]      7.64      2.00      7.59      4.22     10.82    602.67      1.00
beta[0,11]     10.94      2.01     11.01      7.74     14.01    803.46      1.00
 beta[1,0]     41.13      3.72     40.87     35.36     47.45    294.46      1.00
 beta[1,1]     37.47      4.01     37.73     30.86     43.60    274.05      1.00
 beta[1,2]     44.81      3.81     44.62     39.36     51.19    373.35      1.00
 beta[1,3]     31.45      3.68     31.47     26.03     38.26    499.04      1.00
 beta[1,4]     39.25      3.74     39.25     33.99     45.84    327.53      1.00
 beta[1,5]     33.63      3.66     33.62     27.48     39.48    535.10      1.00
 beta[1,6]     25.75      3.68     25.54     20.20     31.82    485.46      1.00
 beta[1,7]     30.41      3.55     30.30     25.06     36.66    319.34      1.00
 beta[1,8]     25.89      3.73     25.82     19.97     31.73    307.24      1.00
 beta[1,9]     19.24      3.86     19.20     12.78     24.82    514.52      1.01
beta[1,10]     14.15      3.80     14.05      8.83     20.76    286.51      1.00
beta[1,11]     22.33      3.90     22.39     16.56     28.69    250.93      1.00
 beta[2,0]     30.70      5.89     30.87     21.00     39.83    218.51      1.00
 beta[2,1]     43.32      6.25     43.17     31.48     52.51    365.20      1.00
 beta[2,2]     38.19      6.03     38.38     28.38     47.50    315.04      1.00
 beta[2,3]     34.33      6.01     33.96     25.13     44.26    376.64      1.00
 beta[2,4]     36.97      5.78     36.70     27.90     46.70    287.45      1.00
 beta[2,5]     36.23      5.68     36.39     26.48     45.10    335.15      1.00
 beta[2,6]     31.01      5.85     31.13     20.69     39.98    387.93      1.00
 beta[2,7]     33.15      5.64     33.24     24.02     41.86    238.67      1.00
 beta[2,8]     28.14      5.80     28.03     18.82     37.71    262.62      1.01
 beta[2,9]     16.99      6.29     16.97      5.62     26.27    375.90      1.00
beta[2,10]     10.97      6.47     10.75      0.52     20.80    246.43      1.00
beta[2,11]     13.48      6.39     13.07      3.82     23.94    166.02      1.00
 beta[3,0]     43.46      9.05     43.38     30.07     58.98    160.78      1.01
 beta[3,1]     40.72      9.80     40.54     23.31     55.08    375.00      1.00
 beta[3,2]     51.19      9.04     51.03     36.54     65.81    279.54      1.00
 beta[3,3]     33.55      9.70     33.58     16.81     48.18    380.20      1.00
 beta[3,4]     42.53      8.86     42.41     29.50     57.63    276.04      1.00
 beta[3,5]     44.49      8.99     44.56     29.90     59.10    320.28      1.00
 beta[3,6]     34.04      9.10     33.86     18.79     48.64    388.05      1.00
 beta[3,7]     32.81      9.16     32.71     18.91     47.81    245.83      1.00
 beta[3,8]     31.09      9.19     31.39     13.62     43.80    248.20      1.00
 beta[3,9]     24.42      9.38     24.31     10.53     40.01    389.07      1.00
beta[3,10]     15.10      9.82     15.47     -1.38     30.21    228.23      1.00
beta[3,11]     23.38      9.72     23.52      8.32     39.32    160.46      1.00
 beta[4,0]     32.31     22.64     31.24     -3.10     70.85    241.13      1.00
 beta[4,1]     37.32     24.16     38.20      3.44     80.53    407.01      1.00
 beta[4,2]     24.10     22.41     23.79    -11.63     62.44    347.81      1.00
 beta[4,3]     37.11     24.25     38.23     -3.03     70.66    349.09      1.00
 beta[4,4]     20.83     22.29     21.67    -13.15     58.67    329.19      1.00
 beta[4,5]     25.82     23.11     26.21     -9.84     68.61    371.96      1.00
 beta[4,6]     26.63     23.04     25.68    -16.67     60.41    422.46      1.00
 beta[4,7]     21.17     23.67     22.25    -16.18     60.12    273.25      1.00
 beta[4,8]     17.28     23.27     17.68    -21.30     52.39    324.63      1.00
 beta[4,9]     18.38     23.02     18.19    -13.46     55.14    419.31      1.00
beta[4,10]     11.79     23.83     11.44    -25.03     48.55    253.85      1.00
beta[4,11]      9.50     23.80      9.69    -26.23     49.25    195.24      1.00
 beta[5,0]     39.49      1.95     39.52     36.04     42.50   1449.89      1.00
 beta[5,1]     44.12      1.93     44.08     40.64     46.90   1125.14      1.00
 beta[5,2]     45.41      2.21     45.38     41.92     48.87    468.95      1.00
 beta[5,3]     38.56      2.18     38.59     34.91     42.08    661.23      1.00
 beta[5,4]     42.21      2.11     42.32     38.65     45.23    615.67      1.00
 beta[5,5]     41.19      1.98     41.11     38.21     44.66    756.89      1.00
 beta[5,6]     35.32      2.11     35.43     31.69     38.64   1030.06      1.00
 beta[5,7]     31.34      2.05     31.39     27.53     34.23    928.45      1.00
 beta[5,8]     27.55      1.94     27.57     24.71     31.08   1275.81      1.00
 beta[5,9]     21.57      1.89     21.64     18.33     24.45    982.11      1.00
beta[5,10]     14.30      2.08     14.28     10.81     17.45    958.17      1.00
beta[5,11]     19.76      1.96     19.83     16.75     23.04    805.23      1.00
     sigma      2.04      0.29      2.02      1.56      2.48     86.44      1.00
       tau     31.47      2.90     31.15     26.90     36.18    586.94      1.00

Number of divergences: 0
# Request 1000 posterior draws and generate predictions from them.
ir_samples = ir_mcmc.get_samples(num_samples=1000)
ir_sampler = Predictive(individual_response, posterior_samples=ir_samples)
ir_post_pred = ir_sampler(Bi, groups, None)

# 4 x 3 grid of panels; each panel shows 7 consecutive rows of the data.
fig, axes = plt.subplots(4, 3, figsize=(10, 6), sharey=True, sharex=True)
ir_uptake_draws = ir_post_pred['uptake']
ir_uptake_mean = ir_uptake_draws.mean(axis=0)

for count, ax in enumerate(axes.ravel()):
    idx = 7 * count
    rows = slice(idx, idx + 7)
    # Observed points, mean of the predictive draws, and their HDI band.
    ax.plot(CO2_conc, uptake[rows], '.', lw=1, color='black')
    ax.plot(CO2_conc, ir_uptake_mean[rows], "k", alpha=0.5);
    az.plot_hdi(CO2_conc, ir_uptake_draws[:, rows], color="C2", smooth=False, ax=ax)
    ax.set_title(plant_names[count])

plt.tight_layout()
# Shared axis labels placed on the figure margins.
fig.text(0.4, -0.05, "CO2 concentration", size=18)
fig.text(-0.03, 0.4, "CO2 uptake", size=18, rotation=90);
/Users/gabestechschulte/miniforge3/envs/probs/lib/python3.10/site-packages/arviz/plots/hdiplot.py:156: FutureWarning: hdi currently interprets 2d data as (draw, shape) but this will change in a future release to (chain, draw) for coherence with other functions
  hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False, **hdi_kwargs)