In probabilistic programming languages (PPLs), one needs to compute the joint probability (often unnormalized) of latent values and observed variables under a generative model in order to perform approximate inference. But given a model written as a Python function, how does one translate that function into a log joint probability? The objective of this blog post is to better understand how modern PPLs, NumPyro in particular, perform this translation dynamically, i.e., with functions that can handle a wide variety of user-defined models.
The Model
The example zero_inflated_poisson.py from the NumPyro docs will be used. In this example, the authors model and predict how many fish are caught by visitors in a state park. Many groups of visitors catch zero fish, either because they did not fish at all or because they were unlucky. The authors explicitly model this bimodal behavior (zero versus non-zero) and ascertain which variables contribute to each behavior by fitting a zero-inflated Poisson regression model. We will use NUTS as the inference method to understand the model translation.
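To make the zero-inflated Poisson likelihood concrete, here is a minimal sketch in plain Python. It is not NumPyro's implementation; the names zip_log_pmf, gate, and rate are chosen for illustration only. Here gate is assumed to be the probability of a structural zero (the group did not fish) and rate is the mean of the Poisson component.

```python
import math


def zip_log_pmf(k, gate, rate):
    """Log pmf of a zero-inflated Poisson (illustrative helper, not NumPyro's).

    gate: probability of a structural zero (e.g., the group did not fish).
    rate: mean of the Poisson component (fish caught when fishing).
    """
    # Log pmf of the plain Poisson component.
    log_poisson = k * math.log(rate) - rate - math.lgamma(k + 1)
    if k == 0:
        # A zero arises either structurally or from the Poisson component.
        return math.log(gate + (1.0 - gate) * math.exp(log_poisson))
    # A positive count can only come from the Poisson component.
    return math.log1p(-gate) + log_poisson


# The probabilities over all counts should sum to (approximately) one.
total = sum(math.exp(zip_log_pmf(k, gate=0.3, rate=2.0)) for k in range(100))
```

The two branches mirror the two behaviors described above: zeros get extra probability mass from the gate, while positive counts are scaled by the probability of actually fishing.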
Workflow
Define model using NumPyro primitives
Construct a kernel for inference and feed model into kernel
Perform inference using MCMC
Code
import argparse
import os
import random

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error

import jax.numpy as jnp
from jax.random import PRNGKey
import jax.scipy as jsp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO, autoguide
from numpyro.infer import util

matplotlib.use("Agg")  # noqa: E402
First, we initialize the NUTS kernel with the model. The word kernel is used in a wide range of fields, from probabilistic programming and statistics to deep learning. In PPLs, a kernel typically defines the interface to the sampling algorithm. In this case, we have initialized a NUTS kernel with our model, and this kernel lets our model interface with NUTS, a variant of HMC sampling.
But the sampling algorithm can’t simply interface with a Python function. Our model, in the form of a Python function, needs to be translated into a joint log density function that is used as input to the sampler. This is where NumPyro performs a series of steps to carry out the translation.
When we “feed” the model into the NUTS class, an initialize_model utility function is called. This function calls various helper functions, such as get_potential_fn and find_valid_initial_params, and returns a tuple of (init_params_info, potential_fn, postprocess_fn, model_trace). Here, we are interested in initialize_model and get_potential_fn.
The graph of function calls looks like: initialize model \(\leftrightarrow\) get potential fn \(\leftrightarrow\) potential energy \(\leftrightarrow\) log density, where each function that is called also returns an object. Below, the functions are described in the order in which they are called.
initialize_model
initialize_model is a function that returns a tuple of objects and values used as input to the HMC algorithm. At a high level, our model and data are passed into initialize_model to initialize the model to some values using the observed data and numpyro.sample statements. This initialization allows us to perform inference with NUTS. The helper functions called within initialize_model are described below, since they do most of the work of translating a model into a log joint probability.
get_potential_fn
Inside of initialize_model, the function get_potential_fn is called. Given a model with Pyro primitives, this Python function returns another function which, given unconstrained parameters, evaluates the potential energy (negative log joint density). In addition, it returns a function to transform unconstrained values at sample sites to constrained values within their respective support.
The interesting parts here are the evaluation of the potential energy and the fact that a function is returned. First, we focus on the function potential_energy, which evaluates the potential energy. Later, we return to the potential_fn object.
Code
def get_potential_fn(
    model,
    inv_transforms,
    *,
    enum=False,
    replay_model=False,
    dynamic_args=False,
    model_args=(),
    model_kwargs=None,
):
    """
    (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns a
    function which, given unconstrained parameters, evaluates the potential
    energy (negative log joint density). In addition, this returns a
    function to transform unconstrained values at sample sites to constrained
    values within their respective support.

    :param model: Python callable containing Pyro primitives.
    :param dict inv_transforms: dictionary of transforms keyed by names.
    :param bool enum: whether to enumerate over discrete latent sites.
    :param bool replay_model: whether we need to replay model in
        `postprocess_fn` to obtain `deterministic` sites.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: tuple of (`potential_fn`, `postprocess_fn`). The latter is used
        to constrain unconstrained samples (e.g. those returned by HMC)
        to values that lie within the site's support, and return values at
        `deterministic` sites in the model.
    """
    if dynamic_args:
        potential_fn = partial(
            _partial_args_kwargs, partial(potential_energy, model, enum=enum)
        )
        if replay_model:
            # XXX: we seed to sample discrete sites (but not collect them)
            model_ = seed(model.fn, 0) if enum else model
            postprocess_fn = partial(
                _partial_args_kwargs,
                partial(constrain_fn, model, return_deterministic=True),
            )
        else:
            postprocess_fn = partial(
                _drop_args_kwargs, partial(transform_fn, inv_transforms)
            )
    else:
        model_kwargs = {} if model_kwargs is None else model_kwargs
        potential_fn = partial(
            potential_energy, model, model_args, model_kwargs, enum=enum
        )
        if replay_model:
            model_ = seed(model.fn, 0) if enum else model
            postprocess_fn = partial(
                constrain_fn,
                model_,
                model_args,
                model_kwargs,
                return_deterministic=True,
            )
        else:
            postprocess_fn = partial(transform_fn, inv_transforms)
    print(f"potential_fn: {potential_fn}")
    return potential_fn, postprocess_fn
potential_energy
This function computes the potential energy (negative joint log density) of a model given unconstrained parameters. Under the hood, NumPyro transforms these unconstrained parameters to values that lie within the supports of the corresponding priors in the model. To compute the potential energy, this function calls a log_density function that computes the log of the joint density for the model given the latent values (parameters).
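To illustrate what a potential energy over unconstrained parameters can look like, here is a minimal sketch in plain Python. It is a toy, not NumPyro's code: the model (sigma ~ Exponential(1), y ~ Normal(0, sigma)), the helper names, and the choice of the exp transform are all assumptions made for this example. The log-Jacobian term comes from the change of variables between the unconstrained z and the constrained sigma.

```python
import math


def normal_log_prob(x, loc, scale):
    # Log density of Normal(loc, scale) evaluated at x.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def exponential_log_prob(x, rate):
    # Log density of Exponential(rate) evaluated at x >= 0.
    return math.log(rate) - rate * x


def toy_potential_energy(z, y):
    """Negative log joint of a toy model as a function of unconstrained z.

    Toy model (hypothetical): sigma ~ Exponential(1), y ~ Normal(0, sigma).
    """
    sigma = math.exp(z)  # unconstrained -> constrained (sigma > 0)
    log_det_jacobian = z  # log |d sigma / d z| = log(exp(z)) = z
    log_joint = (
        exponential_log_prob(sigma, 1.0)
        + normal_log_prob(y, 0.0, sigma)
        + log_det_jacobian
    )
    return -log_joint


energy = toy_potential_energy(z=0.0, y=1.0)
```

Because z can be any real number, a sampler can move freely in the unconstrained space while sigma always stays positive.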
Code
def potential_energy(model, model_args, model_kwargs, params, enum=False):
    """
    (EXPERIMENTAL INTERFACE) Computes potential energy of a model given
    unconstrained params. Under the hood, we will transform these
    unconstrained parameters to the values belong to the supports of the
    corresponding priors in `model`.

    :param model: a callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: unconstrained parameters of `model`.
    :param bool enum: whether to enumerate over discrete latent sites.
    :return: potential energy given unconstrained parameters.
    """
    if enum:
        from numpyro.contrib.funsor import log_density as log_density_
    else:
        log_density_ = log_density

    substituted_model = substitute(
        model, substitute_fn=partial(_unconstrain_reparam, params)
    )
    # no param is needed for log_density computation because we already substitute
    log_joint, model_trace = log_density_(
        substituted_model, model_args, model_kwargs, {}
    )
    print(f"-log_joint: {log_joint}")
    return -log_joint
Given our NumPyro model, data (model_args), and initialized parameters (using numpyro.sample), the potential energy (negative log joint density) is the following output:
The log_density function first uses the effect handler substitute to return a callable which substitutes all primitive calls in fn with values from data whose key matches the site name. If the site name is not present in data, then there is no side effect. After substitute, another effect handler, trace, is used to record the inputs, distributions, and outputs of the numpyro.sample statements in the model and, more generally, of any NumPyro primitive call.
Code
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, data=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
The effect handlers allow us to loop through each site in the model trace to compute the joint log probability density. In the for loop, if the site's type is sample, we grab that site's value(s) (the samples from the numpyro.sample statement) and evaluate the log probability of the value(s) under that site's fn (dist.Normal(), dist.MultivariateNormal(), etc.) with site['fn'].log_prob(<some value>). The output snippet below shows the site b1 defined in the model and the value sampled using the numpyro.sample statement.
Subsequently, we can also see the fn of this sample site defined in our model:
site fn: <numpyro.distributions.distribution.Independent object at 0x13a9ceb20>
where the fn is an Independent Normal distribution because we called the .to_event(1) method in our model. Next, the log probability for the sample site value is computed by calling the .log_prob() method. For example, the log probability of the sampled values for site b1 is:
Subsequently, we sum over the log probability for that site. Lastly, the variable log_joint is created for the log joint probability density, and the current site log probability is added to the log joint probability. After looping through each sample site, the log joint then represents the log joint probability density for the model given the latent values (parameters).
And voilà, this output is the log joint probability density (placing a negative sign in front of this array gives the potential energy) for the model given the latent values. The values in the output above represent the log joint probability of the initialized latent values, i.e., no inference has been run yet.
Code
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    model = substitute(model, data=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = jnp.zeros(())
    print("---- inside log_density -----")
    print('\n')
    for site in model_trace.values():
        print(f"site: {site}", '\n')
        if site["type"] == "sample":
            value = site["value"]
            print(f"value: {value}, \n")
            intermediates = site["intermediates"]
            print(f"intermediates: {intermediates}, \n")
            scale = site["scale"]
            print(f"site fn: {site['fn']}, \n")
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                guide_shape = jnp.shape(value)
                model_shape = tuple(
                    site["fn"].shape()
                )  # TensorShape from tfp needs casting to tuple
                try:
                    broadcast_shapes(guide_shape, model_shape)
                except ValueError:
                    raise ValueError(
                        "Model and guide shapes disagree at site: '{}': {} vs {}".format(
                            site["name"], model_shape, guide_shape
                        )
                    )
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            # print(f"before sum log prob. = {log_prob}, \n")
            # log_prob = jnp.sum(log_prob)
            # print(f"after sum log prob. = {log_prob}, \n")
            log_joint = log_joint + log_prob
            print(f"log joint = {log_joint}, \n")
    return log_joint, model_trace
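The accumulation loop in log_density can be mimicked without NumPyro on a hand-built trace. The sketch below is a simplified stand-in under stated assumptions: ToyNormal, toy_log_density, and the contents of example_trace are hypothetical, scalar values replace arrays, and shape checks and intermediates are left out.

```python
import math


class ToyNormal:
    """Hypothetical stand-in for a distribution object with a log_prob method."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def log_prob(self, value):
        z = (value - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)


def toy_log_density(model_trace):
    """Sum the log probability of every sample site in a recorded trace."""
    log_joint = 0.0
    for site in model_trace.values():
        if site["type"] != "sample":
            continue  # e.g. deterministic sites contribute nothing
        log_prob = site["fn"].log_prob(site["value"])
        scale = site.get("scale")
        if scale is not None:
            log_prob = scale * log_prob
        log_joint += log_prob
    return log_joint


# A hand-written trace: one latent site, one observed site, one deterministic site.
example_trace = {
    "b1": {"type": "sample", "fn": ToyNormal(0.0, 1.0), "value": 0.5, "scale": None},
    "y": {"type": "sample", "fn": ToyNormal(0.5, 1.0), "value": 1.5, "scale": None},
    "mu": {"type": "deterministic", "value": 0.5},
}
log_joint = toy_log_density(example_trace)
```

Both the latent site b1 and the observed site y have type sample, so the prior and the likelihood terms are accumulated by the same loop.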
Returning to get_potential_fn
However, log_density and potential_energy compute the log probability at the current latent values; they return a value, not a function. Therefore, we return to get_potential_fn, which returns the log joint probability as a function, potential_fn. Given the latent values, potential_fn evaluates the potential energy using the model arguments (here, X_train and y_train) and the log_density function described above. The potential_fn is then “fed” into the HMC algorithm to perform inference.
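The step from "a value" to "a function" relies on functools.partial, as in the get_potential_fn source shown earlier. A minimal sketch of that pattern follows; toy_potential_energy, toy_log_joint, and the data are hypothetical stand-ins chosen for this example, not NumPyro code.

```python
from functools import partial


def toy_potential_energy(model, model_args, model_kwargs, params):
    """Hypothetical stand-in: returns the negative log joint at `params`."""
    return -model(*model_args, params=params, **model_kwargs)


def toy_log_joint(x, y, params):
    """Toy unnormalized log joint: Gaussian errors around a line through the origin."""
    slope = params["slope"]
    return sum(-0.5 * (yi - slope * xi) ** 2 for xi, yi in zip(x, y))


x_train, y_train = [1.0, 2.0], [2.0, 4.0]

# Bind the model and the data now; the result only needs the latent values.
potential_fn = partial(toy_potential_energy, toy_log_joint, (x_train, y_train), {})

energy_at_truth = potential_fn({"slope": 2.0})
energy_off = potential_fn({"slope": 1.0})
```

After binding, the sampler never has to know about the data or the model: it simply calls potential_fn with new latent values at every step.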
What’s great about get_potential_fn is that the log joint probability density function can be accessed externally given our NumPyro model, and passed into other sampling libraries such as Blackjax.
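To show why a standalone potential function is enough for an external sampler, here is a deliberately simple random-walk Metropolis sampler in plain Python. This is neither HMC/NUTS nor Blackjax's API; it is a swapped-in, simpler algorithm, and random_walk_metropolis and standard_normal_potential are hypothetical names. The only thing the sampler needs from the model is a callable that returns the potential energy.

```python
import math
import random


def random_walk_metropolis(potential_fn, init, n_steps, step_size, seed=0):
    """Random-walk Metropolis driven only by a potential (negative log density)."""
    rng = random.Random(seed)
    x, energy = init, potential_fn(init)
    samples = []
    for _ in range(n_steps):
        proposal = x + rng.gauss(0.0, step_size)
        proposal_energy = potential_fn(proposal)
        # Accept with probability min(1, exp(energy - proposal_energy)).
        if math.log(1.0 - rng.random()) < energy - proposal_energy:
            x, energy = proposal, proposal_energy
        samples.append(x)
    return samples


def standard_normal_potential(x):
    # Negative log density of N(0, 1), up to an additive constant.
    return 0.5 * x * x


samples = random_walk_metropolis(
    standard_normal_potential, init=0.0, n_steps=20000, step_size=1.0
)
sample_mean = sum(samples) / len(samples)
```

Swapping standard_normal_potential for a potential function obtained from a model would leave the sampler code unchanged, which is the point of exposing the log joint as a function.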
Summary
To translate a NumPyro model, in the form of a Python function, into a joint probability (function and value), a series of function calls is performed, namely initialize model \(\leftrightarrow\) get potential fn \(\leftrightarrow\) potential energy \(\leftrightarrow\) log density.
log_density computes the log probability of the current latent (parameter) value and observed data
potential_energy is -log_density
get_potential_fn returns the log joint probability as a function, potential_fn. Given the latent values, potential_fn evaluates the potential energy using the model arguments (here, X_train and y_train) and the log_density function. The potential_fn is then “fed” into the HMC algorithm to perform inference.
This process allows users to (1) translate their models in a dynamic manner, and (2) access the log joint probability for external use, such as in Blackjax.