import arviz as az
import bambi as bmb
import bayeux as bx
import numpy as np
import pandas as pd
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Gabriel Stechschulte
March 29, 2024
This blog post is a copy of the alternative samplers documentation I wrote for Bambi. The original post can be found here.
In Bambi, the sampler used is automatically selected given the type of variables used in the model. For inference, Bambi supports both MCMC and variational inference. By default, Bambi uses PyMC's implementation of the adaptive Hamiltonian Monte Carlo (HMC) algorithm, also known as the No-U-Turn Sampler (NUTS), for sampling. This sampler is a good choice for many models. However, it is not the only sampling method, nor is PyMC the only library implementing NUTS.
To this end, Bambi supports multiple backends for MCMC sampling, such as NumPyro and Blackjax. This notebook will cover how to use such alternatives in Bambi.
Note: Bambi utilizes bayeux to access a variety of sampling backends. Thus, you will need to install the optional dependencies in the Bambi pyproject.toml file to use these backends.
Bambi leverages bayeux to access different sampling backends. In short, bayeux lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. Since the underlying Bambi model is a PyMC model, this PyMC model can be "given" to bayeux. Then, we can choose from a variety of MCMC methods to perform inference.
To demonstrate the available backends, we will first simulate data and build a model.
num_samples = 100
num_features = 1
noise_std = 1.0
random_seed = 42
np.random.seed(random_seed)
coefficients = np.random.randn(num_features)
X = np.random.randn(num_samples, num_features)
error = np.random.normal(scale=noise_std, size=num_samples)
y = X @ coefficients + error
data = pd.DataFrame({"y": y, "x": X.flatten()})
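As a quick sanity check on the simulated data (a sketch added here, not part of the original post), an ordinary least-squares fit should roughly recover the true coefficient, since the data-generating process above is exactly a linear model with Gaussian noise:

```python
import numpy as np

# Reproduce the simulation above: same seed, same draw order.
np.random.seed(42)
coefficients = np.random.randn(1)
X = np.random.randn(100, 1)
error = np.random.normal(scale=1.0, size=100)
y = X @ coefficients + error

# Least-squares estimate of the slope; with 100 observations and unit noise,
# it should land within sampling error of the true coefficient.
beta_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
print(beta_hat[0], coefficients[0])
```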
model = bmb.Model("y ~ x", data)
We can call bmb.inference_methods.names, which returns a nested dictionary of the backends and the list of inference methods.
bmb.inference_methods.names
{'pymc': {'mcmc': ['mcmc'], 'vi': ['vi']},
'bayeux': {'mcmc': ['tfp_hmc',
'tfp_nuts',
'tfp_snaper_hmc',
'blackjax_hmc',
'blackjax_chees_hmc',
'blackjax_meads_hmc',
'blackjax_nuts',
'blackjax_hmc_pathfinder',
'blackjax_nuts_pathfinder',
'flowmc_rqspline_hmc',
'flowmc_rqspline_mala',
'flowmc_realnvp_hmc',
'flowmc_realnvp_mala',
'numpyro_hmc',
'numpyro_nuts']}}
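For example, to get a flat list of every MCMC method across backends, you can walk this dictionary with plain Python (the dictionary literal below is copied from the output above):

```python
# Nested mapping of backend -> kind -> method names, as returned by
# bmb.inference_methods.names (copied from the output above).
methods = {
    "pymc": {"mcmc": ["mcmc"], "vi": ["vi"]},
    "bayeux": {
        "mcmc": [
            "tfp_hmc", "tfp_nuts", "tfp_snaper_hmc",
            "blackjax_hmc", "blackjax_chees_hmc", "blackjax_meads_hmc",
            "blackjax_nuts", "blackjax_hmc_pathfinder", "blackjax_nuts_pathfinder",
            "flowmc_rqspline_hmc", "flowmc_rqspline_mala",
            "flowmc_realnvp_hmc", "flowmc_realnvp_mala",
            "numpyro_hmc", "numpyro_nuts",
        ]
    },
}

# Flatten to every MCMC method name, regardless of backend.
all_mcmc = [name for backend in methods.values() for name in backend.get("mcmc", [])]
print(all_mcmc)
```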
With the PyMC backend, we have access to their implementation of the NUTS sampler and mean-field variational inference. bayeux gives us access to the TensorFlow Probability, Blackjax, FlowMC, and NumPyro backends.
{'mcmc': ['tfp_hmc',
'tfp_nuts',
'tfp_snaper_hmc',
'blackjax_hmc',
'blackjax_chees_hmc',
'blackjax_meads_hmc',
'blackjax_nuts',
'blackjax_hmc_pathfinder',
'blackjax_nuts_pathfinder',
'flowmc_rqspline_hmc',
'flowmc_rqspline_mala',
'flowmc_realnvp_hmc',
'flowmc_realnvp_mala',
'numpyro_hmc',
'numpyro_nuts']}
The values under the MCMC and VI keys in the dictionary are the names of the arguments you would pass to inference_method in model.fit. This is shown in the section below.
inference_method
By default, Bambi uses the PyMC NUTS implementation. To use a different backend, pass the name of the bayeux MCMC method to the inference_method parameter of the fit method.
blackjax_nuts_idata = model.fit(inference_method="blackjax_nuts")
blackjax_nuts_idata
<xarray.Dataset> Size: 100kB Dimensions: (chain: 8, draw: 500) Coordinates: * chain (chain) int64 64B 0 1 2 3 4 5 6 7 * draw (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499 Data variables: Intercept (chain, draw) float64 32kB 0.04421 0.1077 ... 0.0259 0.06753 x (chain, draw) float64 32kB 0.1353 0.232 0.5141 ... 0.2195 0.5014 y_sigma (chain, draw) float64 32kB 0.9443 0.9102 0.922 ... 0.9597 0.9249 Attributes: created_at: 2024-04-13T05:34:49.761913+00:00 arviz_version: 0.18.0 modeling_interface: bambi modeling_interface_version: 0.13.1.dev25+g1e7f677e.d20240413
<xarray.Dataset> Size: 200kB Dimensions: (chain: 8, draw: 500) Coordinates: * chain (chain) int64 64B 0 1 2 3 4 5 6 7 * draw (draw) int64 4kB 0 1 2 3 4 5 6 ... 494 495 496 497 498 499 Data variables: acceptance_rate (chain, draw) float64 32kB 0.8137 1.0 1.0 ... 0.9094 0.9834 diverging (chain, draw) bool 4kB False False False ... False False energy (chain, draw) float64 32kB 142.0 142.5 ... 143.0 141.5 lp (chain, draw) float64 32kB -141.8 -140.8 ... -140.3 -140.3 n_steps (chain, draw) int64 32kB 3 3 7 7 7 3 3 7 ... 7 3 7 5 7 3 7 step_size (chain, draw) float64 32kB 0.7326 0.7326 ... 0.7643 0.7643 tree_depth (chain, draw) int64 32kB 2 2 3 3 3 2 2 3 ... 3 2 3 3 3 2 3 Attributes: created_at: 2024-04-13T05:34:49.763427+00:00 arviz_version: 0.18.0 modeling_interface: bambi modeling_interface_version: 0.13.1.dev25+g1e7f677e.d20240413
Different backends have different naming conventions for the parameters specific to that MCMC method. Thus, to specify backend-specific parameters, pass your own kwargs to the fit method.
The following can be performed to identify the kwargs specific to each method.
bmb.inference_methods.get_kwargs("blackjax_nuts")
{<function blackjax.adaptation.window_adaptation.window_adaptation(algorithm: Union[blackjax.mcmc.hmc.hmc, blackjax.mcmc.nuts.nuts], logdensity_fn: Callable, is_mass_matrix_diagonal: bool = True, initial_step_size: float = 1.0, target_acceptance_rate: float = 0.8, progress_bar: bool = False, **extra_parameters) -> blackjax.base.AdaptationAlgorithm>: {'logdensity_fn': <function bayeux._src.shared.constrain.<locals>.wrap_log_density.<locals>.wrapped(args)>,
'is_mass_matrix_diagonal': True,
'initial_step_size': 1.0,
'target_acceptance_rate': 0.8,
'progress_bar': False,
'algorithm': blackjax.mcmc.nuts.nuts},
'adapt.run': {'num_steps': 500},
blackjax.mcmc.nuts.nuts: {'max_num_doublings': 10,
'divergence_threshold': 1000,
'integrator': <function blackjax.mcmc.integrators.generate_euclidean_integrator.<locals>.euclidean_integrator(logdensity_fn: Callable, kinetic_energy_fn: blackjax.mcmc.metrics.KineticEnergy) -> Callable[[blackjax.mcmc.integrators.IntegratorState, float], blackjax.mcmc.integrators.IntegratorState]>,
'logdensity_fn': <function bayeux._src.shared.constrain.<locals>.wrap_log_density.<locals>.wrapped(args)>,
'step_size': 0.5},
'extra_parameters': {'chain_method': 'vectorized',
'num_chains': 8,
'num_draws': 500,
'num_adapt_draws': 500,
'return_pytree': False}}
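Conceptually, bayeux starts from these defaults and merges in whatever you override, so passing kwargs amounts to a nested-dictionary update. A plain-Python sketch of that merge (the keys are illustrative, taken from the output above; this is not bayeux's actual implementation):

```python
def deep_merge(defaults: dict, overrides: dict) -> dict:
    """Recursively merge overrides into defaults, returning a new dict."""
    merged = dict(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

# Illustrative subset of the defaults shown above.
defaults = {
    "adapt.run": {"num_steps": 500},
    "extra_parameters": {"num_chains": 8, "num_draws": 500, "num_adapt_draws": 500},
}
# Override only the settings we care about; untouched defaults survive.
overrides = {
    "extra_parameters": {"num_chains": 4, "num_draws": 250, "num_adapt_draws": 250},
}
settings = deep_merge(defaults, overrides)
print(settings)
```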
Now, we can identify the kwargs we would like to change and pass them to the fit method.
kwargs = {
"adapt.run": {"num_steps": 500},
"num_chains": 4,
"num_draws": 250,
"num_adapt_draws": 250
}
blackjax_nuts_idata = model.fit(inference_method="blackjax_nuts", **kwargs)
blackjax_nuts_idata
<xarray.Dataset> Size: 26kB Dimensions: (chain: 4, draw: 250) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 2kB 0 1 2 3 4 5 6 7 ... 243 244 245 246 247 248 249 Data variables: Intercept (chain, draw) float64 8kB 0.112 -0.08016 ... -0.04784 -0.04427 x (chain, draw) float64 8kB 0.4311 0.3077 0.2568 ... 0.557 0.4814 y_sigma (chain, draw) float64 8kB 0.9461 0.9216 0.9021 ... 0.9574 0.9414 Attributes: created_at: 2024-04-13T05:36:20.439151+00:00 arviz_version: 0.18.0 modeling_interface: bambi modeling_interface_version: 0.13.1.dev25+g1e7f677e.d20240413
<xarray.Dataset> Size: 51kB Dimensions: (chain: 4, draw: 250) Coordinates: * chain (chain) int64 32B 0 1 2 3 * draw (draw) int64 2kB 0 1 2 3 4 5 6 ... 244 245 246 247 248 249 Data variables: acceptance_rate (chain, draw) float64 8kB 0.8489 0.9674 ... 0.983 1.0 diverging (chain, draw) bool 1kB False False False ... False False energy (chain, draw) float64 8kB 141.2 140.7 140.5 ... 141.9 141.4 lp (chain, draw) float64 8kB -139.9 -140.0 ... -141.6 -140.3 n_steps (chain, draw) int64 8kB 3 7 7 3 3 3 7 3 ... 3 3 3 3 3 3 3 3 step_size (chain, draw) float64 8kB 0.8923 0.8923 ... 0.9726 0.9726 tree_depth (chain, draw) int64 8kB 2 3 3 2 2 2 3 2 ... 2 2 2 2 2 2 2 2 Attributes: created_at: 2024-04-13T05:36:20.441267+00:00 arviz_version: 0.18.0 modeling_interface: bambi modeling_interface_version: 0.13.1.dev25+g1e7f677e.d20240413
2, 2, 4, 2, 3, 3, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2, 2, 2, 3, 3, 2, 3, 2, 2, 2, 2, 3, 2, 3, 2, 3, 2, 1, 3, 2, 2, 3, 2, 3, 3, 3, 2, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 3, 2, 2, 2, 4, 2, 1, 2, 2, 2, 3, 4, 2, 2, 4, 2, 2, 2, 2, 2, 2, 1, 3, 1, 2, 3, 2, 2, 2, 2, 2, 2, 3, 2, 1, 2, 2, 3, 2, 2, 3, 3, 2, 2, 2, 3, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 3, 3, 2, 2, 5, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 4, 2], [2, 2, 2, 2, 4, 3, 2, 2, 2, 4, 2, 2, 2, 4, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 3, 2, 1, 2, 2, 2, 4, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 3, 3, 2, 3, 3, 2, 2, 2, 4, 2, 2, 2, 2, 2, 2, 3, 3, 2, 3, 2, 2, 2, 2, 3, 2, 3, 2, 3, 2, 2, 2, 4, 2, 2, 2, 4, 2, 2, 3, 2, 2, 2, 2, 2, 2, 1, 3, 2, 2, 2, 2, 1, 2, 2, 2, 3, 2, 2, 2, 2, 2, 3, 3, 2, 4, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 3, 2, 2, 3, 2, 2, 1, 2, 2, 2, 2, 3, 3, 2, 1, 2, 1, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 3, 3, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2, 2, 2, 4, 2, 1, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 3, 2, 3, 2, 3, 2, 2, 2, 2, 2, 2, 2, 1, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]])
PandasIndex(Index([0, 1, 2, 3], dtype='int64', name='chain'))
PandasIndex(Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 240, 241, 242, 243, 244, 245, 246, 247, 248, 249], dtype='int64', name='draw', length=250))
<xarray.Dataset> Size: 200kB
Dimensions:    (chain: 8, draw: 1000)
Coordinates:
  * chain      (chain) int64 64B 0 1 2 3 4 5 6 7
  * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
    Intercept  (chain, draw) float64 64kB -0.1597 0.2011 ... 0.1525 -0.171
    x          (chain, draw) float64 64kB 0.2515 0.4686 0.4884 ... 0.5085 0.4896
    y_sigma    (chain, draw) float64 64kB 0.9735 0.8969 0.8002 ... 0.9422 1.045
Attributes:
    created_at:                  2024-04-13T05:36:30.303342+00:00
    arviz_version:               0.18.0
    modeling_interface:          bambi
    modeling_interface_version:  0.13.1.dev25+g1e7f677e.d20240413
<xarray.Dataset> Size: 312kB
Dimensions:          (chain: 8, draw: 1000)
Coordinates:
  * chain            (chain) int64 64B 0 1 2 3 4 5 6 7
  * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
Data variables:
    accept_ratio     (chain, draw) float64 64kB 0.8971 0.9944 ... 0.9174 0.8133
    diverging        (chain, draw) bool 8kB False False False ... False False
    is_accepted      (chain, draw) bool 8kB True True True ... True True True
    n_steps          (chain, draw) int32 32kB 7 7 7 1 7 7 7 3 ... 3 7 7 7 3 7 7
    step_size        (chain, draw) float64 64kB 0.534 0.534 0.534 ... nan nan
    target_log_prob  (chain, draw) float64 64kB -141.5 -141.7 ... -141.0 -143.2
    tune             (chain, draw) float64 64kB 0.0 0.0 0.0 0.0 ... nan nan nan
Attributes:
    created_at:                  2024-04-13T05:36:30.304788+00:00
    arviz_version:               0.18.0
    modeling_interface:          bambi
    modeling_interface_version:  0.13.1.dev25+g1e7f677e.d20240413
sample: 100%|██████████| 1500/1500 [00:02<00:00, 599.25it/s]
<xarray.Dataset> Size: 200kB
Dimensions:    (chain: 8, draw: 1000)
Coordinates:
  * chain      (chain) int64 64B 0 1 2 3 4 5 6 7
  * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
    Intercept  (chain, draw) float64 64kB 0.0004764 0.02933 ... 0.1217 0.1668
    x          (chain, draw) float64 64kB 0.3836 0.6556 0.2326 ... 0.48 0.5808
    y_sigma    (chain, draw) float64 64kB 0.8821 0.9604 0.9652 ... 0.9063 0.9184
Attributes:
    created_at:                  2024-04-13T05:36:33.599519+00:00
    arviz_version:               0.18.0
    inference_library:           numpyro
    inference_library_version:   0.14.0
    modeling_interface:          bambi
    modeling_interface_version:  0.13.1.dev25+g1e7f677e.d20240413
<xarray.Dataset> Size: 400kB
Dimensions:          (chain: 8, draw: 1000)
Coordinates:
  * chain            (chain) int64 64B 0 1 2 3 4 5 6 7
  * draw             (draw) int64 8kB 0 1 2 3 4 5 6 ... 994 995 996 997 998 999
Data variables:
    acceptance_rate  (chain, draw) float64 64kB 0.9366 0.3542 ... 0.9903 0.8838
    diverging        (chain, draw) bool 8kB False False False ... False False
    energy           (chain, draw) float64 64kB 140.3 145.4 ... 141.0 142.6
    lp               (chain, draw) float64 64kB 139.6 143.3 ... 140.5 142.5
    n_steps          (chain, draw) int64 64kB 3 3 7 3 7 1 3 3 ... 11 7 3 3 3 3 3
    step_size        (chain, draw) float64 64kB 0.8891 0.8891 ... 0.7595 0.7595
    tree_depth       (chain, draw) int64 64kB 2 2 3 2 3 1 2 2 ... 4 3 2 2 2 2 2
Attributes:
    created_at:                  2024-04-13T05:36:33.623197+00:00
    arviz_version:               0.18.0
    inference_library:           numpyro
    inference_library_version:   0.14.0
    modeling_interface:          bambi
    modeling_interface_version:  0.13.1.dev25+g1e7f677e.d20240413
No autotune found, use input sampler_params
Training normalizing flow
Tuning global sampler: 100%|██████████| 5/5 [00:51<00:00, 10.37s/it]
Starting Production run
Production run: 100%|██████████| 5/5 [00:00<00:00, 14.38it/s]
<xarray.Dataset> Size: 244kB
Dimensions:    (chain: 20, draw: 500)
Coordinates:
  * chain      (chain) int64 160B 0 1 2 3 4 5 6 7 8 ... 12 13 14 15 16 17 18 19
  * draw       (draw) int64 4kB 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
Data variables:
    Intercept  (chain, draw) float64 80kB -0.07404 -0.07404 ... -0.1455 0.09545
    x          (chain, draw) float64 80kB 0.4401 0.4401 0.3533 ... 0.6115 0.3824
    y_sigma    (chain, draw) float64 80kB 0.9181 0.9181 0.9732 ... 1.049 0.9643
Attributes:
    created_at:                  2024-04-13T05:37:29.798250+00:00
    arviz_version:               0.18.0
    modeling_interface:          bambi
    modeling_interface_version:  0.13.1.dev25+g1e7f677e.d20240413
With ArviZ, we can compare the inference result summaries of the samplers. Note: we can't use az.compare because not every inference data object contains the pointwise log-likelihood values it requires, so calling it would raise an error.
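A summary like the ones below can be produced with az.summary, which works the same on an InferenceData object from any backend. A minimal, self-contained sketch using a synthetic posterior (the array here is a stand-in for the draws returned by one of the fits above):

```python
import arviz as az
import numpy as np

# Stand-in posterior with 4 chains and 250 draws; az.summary does not care
# which sampling backend produced the InferenceData object.
rng = np.random.default_rng(42)
idata = az.from_dict(posterior={"x": rng.normal(loc=0.36, scale=0.1, size=(4, 250))})
summary = az.summary(idata)
print(summary)
```

The resulting DataFrame has the same columns (mean, sd, HDI bounds, MCSE, ESS, r_hat) as the tables below, which makes side-by-side comparison of backends straightforward.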
|           | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-----------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| Intercept | 0.023 | 0.097 | -0.141 | 0.209   | 0.004     | 0.003   | 694.0    | 508.0    | 1.00  |
| x         | 0.356 | 0.111 | 0.162  | 0.571   | 0.004     | 0.003   | 970.0    | 675.0    | 1.00  |
| y_sigma   | 0.950 | 0.069 | 0.827  | 1.072   | 0.002     | 0.001   | 1418.0   | 842.0    | 1.01  |
|           | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-----------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| Intercept | 0.023 | 0.097 | -0.157 | 0.205   | 0.001     | 0.001   | 6785.0   | 5740.0   | 1.0   |
| x         | 0.360 | 0.105 | 0.169  | 0.563   | 0.001     | 0.001   | 6988.0   | 5116.0   | 1.0   |
| y_sigma   | 0.946 | 0.067 | 0.831  | 1.081   | 0.001     | 0.001   | 7476.0   | 5971.0   | 1.0   |
|           | mean  | sd    | hdi_3% | hdi_97% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|-----------|-------|-------|--------|---------|-----------|---------|----------|----------|-------|
| Intercept | 0.024 | 0.095 | -0.162 | 0.195   | 0.001     | 0.001   | 6851.0   | 5614.0   | 1.0   |
| x         | 0.362 | 0.104 | 0.176  | 0.557   | 0.001     | 0.001   | 9241.0   | 6340.0   | 1.0   |
| y_sigma   | 0.946 | 0.068 | 0.826  | 1.079   | 0.001     | 0.001   | 7247.0   | 5711.0   | 1.0   |
Thanks to bayeux, we can use three different sampling backends and more than ten alternative MCMC methods in Bambi. Using these methods is as simple as passing the method name to the inference_method parameter of the fit method.