Translating a Model into a Log Joint Probability

probability
inference
jax
Author

Gabriel Stechschulte

Published

March 26, 2023

Objective

In probabilistic programming languages (PPLs), performing approximate inference requires computing the joint probability (often unnormalized) of latent values and observed variables under a generative model. However, given a model in the form of a Python function, how does one translate this function (model) into a log joint probability? The objective of this blog is to better understand how modern PPLs, in particular NumPyro, perform this translation in a dynamic way, i.e., how the translation machinery can handle a variety of user-defined models.
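
As a toy illustration of the translation we are after, consider a model with a single latent variable. The sketch below (hypothetical, not the example used in this post) shows a model and, written out by hand, the log joint density a PPL must recover from it: the prior log density plus the likelihood log density. NumPyro's numpyro.infer.util.log_density (discussed later) confirms the correspondence.

Code
import jax.numpy as jnp
import jax.scipy.stats as jss
import numpyro
import numpyro.distributions as dist
from numpyro.infer.util import log_density

def toy_model(y):
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0))  # prior
    numpyro.sample("y", dist.Normal(mu, 1.0), obs=y)  # likelihood

def toy_log_joint(mu, y):
    # log p(mu, y) = log p(mu) + sum_i log p(y_i | mu)
    return jss.norm.logpdf(mu) + jnp.sum(jss.norm.logpdf(y, loc=mu))

y = jnp.array([0.5, -1.0, 2.0])
mu = jnp.array(0.3)
log_joint, _ = log_density(toy_model, (y,), {}, {"mu": mu})
print(jnp.allclose(log_joint, toy_log_joint(mu, y)))  # True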

The Model

The example zero_inflated_poisson.py from the NumPyro docs will be used. In this example, the authors model and predict how many fish are caught by visitors to a state park. Many groups of visitors catch zero fish, either because they did not fish at all or because they were unlucky. The authors explicitly model this bimodal behavior (zero versus non-zero) and ascertain which variables contribute to each behavior by fitting a zero-inflated Poisson regression model. We will use NUTS as the inference method to understand the model translation.
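
For reference, a zero-inflated Poisson with gate \(q\) (the probability of an excess zero) and rate \(\lambda\) has the likelihood

\[
p(y \mid q, \lambda) =
\begin{cases}
q + (1 - q)\, e^{-\lambda} & \text{if } y = 0 \\
(1 - q)\, \dfrac{\lambda^{y} e^{-\lambda}}{y!} & \text{if } y > 0
\end{cases}
\]

which is the distribution the dist.ZeroInflatedPoisson(gate=q, rate=lam) site in the model below evaluates.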

Workflow

  1. Define model using NumPyro primitives
  2. Construct a kernel for inference and feed model into kernel
  3. Perform inference using MCMC
Code
import argparse
import os
import random

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error

import jax.numpy as jnp
from jax.random import PRNGKey
import jax.scipy as jsp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO, autoguide
from numpyro.infer import util

matplotlib.use("Agg")  # noqa: E402
Code
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
def model(X, Y):
    D_X = X.shape[1]
    # regression coefficients for the gate (b1) and the Poisson rate (b2)
    b1 = numpyro.sample("b1", dist.Normal(0.0, 1.0).expand([D_X]).to_event(1))
    b2 = numpyro.sample("b2", dist.Normal(0.0, 1.0).expand([D_X]).to_event(1))

    # gate q: probability of an excess zero; lam: Poisson rate
    q = jsp.special.expit(jnp.dot(X, b1[:, None])).reshape(-1)
    lam = jnp.exp(jnp.dot(X, b2[:, None]).reshape(-1))

    with numpyro.plate("obs", X.shape[0]):
        numpyro.sample("Y", dist.ZeroInflatedPoisson(gate=q, rate=lam), obs=Y)
def run_mcmc(model, args, X, Y):
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
    )
    mcmc.run(PRNGKey(1), X, Y)
    mcmc.print_summary()
    return mcmc.get_samples()
Code
def main(args):
    set_seed(args.seed)

    # prepare dataset
    df = pd.read_stata("http://www.stata-press.com/data/r11/fish.dta")
    df["intercept"] = 1
    cols = ["livebait", "camper", "persons", "child", "intercept"]

    # note: randn draws from a standard normal, so this puts ~79% of rows in
    # train for train_size=0.8 (np.random.rand would make train_size the exact
    # expected fraction)
    mask = np.random.randn(len(df)) < args.train_size
    df_train = df[mask]
    df_test = df[~mask]
    X_train = jnp.asarray(df_train[cols].values)
    y_train = jnp.asarray(df_train["count"].values)
    X_test = jnp.asarray(df_test[cols].values)
    y_test = jnp.asarray(df_test["count"].values)

    print("run MCMC.")
    posterior_samples = run_mcmc(model, args, X_train, y_train)

    predictive = Predictive(model, posterior_samples=posterior_samples)
    predictions = predictive(PRNGKey(1), X=X_test, Y=None)
    mcmc_predictions = jnp.rint(predictions["Y"].mean(0))

    print(
        "MCMC RMSE: ",
        mean_squared_error(np.asarray(y_test), np.asarray(mcmc_predictions), squared=False),
    )
Code
parser = argparse.ArgumentParser("Zero-Inflated Poisson Regression")
parser.add_argument("--seed", nargs="?", default=42, type=int)
parser.add_argument("-n", "--num-samples", nargs="?", default=2000, type=int)
parser.add_argument("--num-warmup", nargs="?", default=1000, type=int)
parser.add_argument("--num-chains", nargs="?", default=1, type=int)
parser.add_argument("--num-data", nargs="?", default=100, type=int)
parser.add_argument("--maxiter", nargs="?", default=5000, type=int)
parser.add_argument("--train-size", nargs="?", default=0.8, type=float)
parser.add_argument("--device", default="cpu", type=str, help='use "cpu" or "gpu".')
args = parser.parse_args("")

numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)

main(args)

Initializing the kernel

First, we initialize the NUTS kernel with the model. The word kernel is used across many fields, including probabilistic programming, statistics, and deep learning. In PPLs, the name kernel typically refers to the interface with the sampling algorithm. In this case, we have initialized a NUTS kernel with our model, and this kernel will allow us to interface our model with NUTS, a variant of the underlying HMC sampler.

But the sampling algorithm can’t interface with a Python function directly. Our model, in the form of a Python function, needs to be translated into a joint log density function that can be used as input to the sampler. This is where NumPyro performs a series of steps to carry out the translation.

When we “feed” the model into the NUTS class, an initialize_model utility function is called. This function calls various helper functions such as get_potential_fn and find_valid_initial_params and returns a tuple of (param_info, potential_fn, postprocess_fn, model_trace). Here, we are interested in initialize_model and get_potential_fn.

The graph of function calls looks like: initialize model \(\leftrightarrow\) get potential fn \(\leftrightarrow\) potential energy \(\leftrightarrow\) log density, where each function called also returns an object to its caller. Below, the sequential order of function calls is described.

initialize_model

initialize_model is a function that returns a tuple of objects and values used as input to the HMC algorithm. At a high level, our model and data are passed into the initialize_model function to initialize the model to some values using the observed data and the numpyro.sample statements. This initialization allows us to perform inference with NUTS. Below, the various helper functions called within this function are described, as these helpers are where the majority of our interest lies in translating a model into a log joint probability.
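
Although NUTS calls initialize_model for us, the function is also part of NumPyro's public numpyro.infer.util module, so we can call it directly to inspect what the sampler receives. A minimal sketch, assuming the model, X_train, and y_train defined in the code above:

Code
from jax.random import PRNGKey
from numpyro.infer.util import initialize_model

model_info = initialize_model(PRNGKey(0), model, model_args=(X_train, y_train))
# model_info is a namedtuple: (param_info, potential_fn, postprocess_fn, model_trace)
print(model_info.param_info.z)  # initial unconstrained values for "b1" and "b2"
print(model_info.potential_fn)  # callable that evaluates the potential energy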

get_potential_fn

Inside of initialize_model, the function get_potential_fn is called. Given a model with Pyro primitives, this Python function returns another function which, given unconstrained parameters, evaluates the potential energy (negative log joint density). In addition, it returns a function to transform unconstrained values at sample sites to constrained values within their respective supports.

The interesting parts here are the evaluation of the potential energy and the fact that a function is returned. First, we focus on potential_energy, the function that evaluates the potential energy. Later, we return to the potential_fn object.

Code
def get_potential_fn(
    model,
    inv_transforms,
    *,
    enum=False,
    replay_model=False,
    dynamic_args=False,
    model_args=(),
    model_kwargs=None,
):
    """
    (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns a
    function which, given unconstrained parameters, evaluates the potential
    energy (negative log joint density). In addition, this returns a
    function to transform unconstrained values at sample sites to constrained
    values within their respective support.

    :param model: Python callable containing Pyro primitives.
    :param dict inv_transforms: dictionary of transforms keyed by names.
    :param bool enum: whether to enumerate over discrete latent sites.
    :param bool replay_model: whether we need to replay model in
        `postprocess_fn` to obtain `deterministic` sites.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: tuple of (`potential_fn`, `postprocess_fn`). The latter is used
        to constrain unconstrained samples (e.g. those returned by HMC)
        to values that lie within the site's support, and return values at
        `deterministic` sites in the model.
    """
    if dynamic_args:
        potential_fn = partial(
            _partial_args_kwargs, partial(potential_energy, model, enum=enum)
        )
        if replay_model:
            # XXX: we seed to sample discrete sites (but not collect them)
            model_ = seed(model.fn, 0) if enum else model
            postprocess_fn = partial(
                _partial_args_kwargs,
                partial(constrain_fn, model, return_deterministic=True),
            )
        else:
            postprocess_fn = partial(
                _drop_args_kwargs, partial(transform_fn, inv_transforms)
            )
    else:
        model_kwargs = {} if model_kwargs is None else model_kwargs
        potential_fn = partial(
            potential_energy, model, model_args, model_kwargs, enum=enum
        )
        if replay_model:
            model_ = seed(model.fn, 0) if enum else model
            postprocess_fn = partial(
                constrain_fn,
                model_,
                model_args,
                model_kwargs,
                return_deterministic=True,
            )
        else:
            postprocess_fn = partial(transform_fn, inv_transforms)

    print(f"potential_fn: {potential_fn}")
    return potential_fn, postprocess_fn

potential_energy

potential_energy computes the potential energy (negative log joint density) of a model given unconstrained parameters. Under the hood, NumPyro transforms these unconstrained parameters into values belonging to the supports of the corresponding priors in the model. To compute the potential energy, this function calls a log_density function that computes the log joint density of the model given the latent values (parameters).

Code
def potential_energy(model, model_args, model_kwargs, params, enum=False):
    """
    (EXPERIMENTAL INTERFACE) Computes potential energy of a model given unconstrained params.
    Under the hood, we will transform these unconstrained parameters to the values
    belong to the supports of the corresponding priors in `model`.

    :param model: a callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: unconstrained parameters of `model`.
    :param bool enum: whether to enumerate over discrete latent sites.
    :return: potential energy given unconstrained parameters.
    """
    if enum:
        from numpyro.contrib.funsor import log_density as log_density_
    else:
        log_density_ = log_density

    substituted_model = substitute(
        model, substitute_fn=partial(_unconstrain_reparam, params)
    )
    # no param is needed for log_density computation because we already substitute
    log_joint, model_trace = log_density_(
        substituted_model, model_args, model_kwargs, {}
    )
    print(f"-log_joint: {log_joint}")
    return -log_joint
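
Because potential_energy is importable from numpyro.infer.util, we can also evaluate it at a point of our choosing. A minimal sketch (the all-zero parameter values are hypothetical; since both priors are Normal, the unconstrained and constrained values coincide):

Code
from numpyro.infer.util import potential_energy

params = {"b1": jnp.zeros(X_train.shape[1]), "b2": jnp.zeros(X_train.shape[1])}
pe = potential_energy(model, (X_train, y_train), {}, params)
print(pe)  # negative log joint density at `params` (a scalar in stock NumPyro)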

Given our NumPyro model, data (model_args), and the initialized parameters, the potential energy (negative log joint density) is the following output. Note that it is an array rather than a scalar only because the jnp.sum inside log_density has been commented out for inspection (see the full listing further below); stock NumPyro sums each site's log probability to a scalar.

-log_joint: Traced<ConcreteArray([ 15.773586   15.796449   15.68768    17.106882   16.307858   15.10714
  16.277817   16.401972   17.409088   15.488239   17.44108    20.011204
  15.796449   16.401972   17.409088   15.488239   32.46398    16.306961
  18.321064   15.776845   16.432753   41.972443   16.90719    17.324219
  15.487675   15.68768    15.10714    15.68768    16.96474    61.152782
  16.319      32.46398    56.631355   16.733091   44.390583   15.796449
  31.771408   15.9031725  17.999393   15.929144   15.796449   15.776845
  25.234818   15.487675   32.46398    15.776845   22.136528   16.745249
  15.796449   15.796449   61.152782   17.459694   15.776845   39.30664
  31.771408   17.106882   15.796449   15.776845   16.836826   16.90372
  15.565512   15.266311   15.796449   15.487675   25.503807   66.416145
  42.01054    15.68768    16.438005   35.528217   16.401972  275.8356
  15.488239   46.813717   18.31475    42.01054    15.929144   16.733091
  15.929144   18.632296   16.553946   22.139755   16.879503   16.253452
  15.929144   16.68712    16.90719    17.409088   16.306961   17.900412
  72.883484   20.984446   17.080605   15.68768    15.266311   17.459694
  15.487675   17.409088  221.04407    15.487675   15.68768    15.68768
  15.903938   17.608683   17.233418   16.945618   17.102604   16.230682
  16.401972   20.437622   16.307858   15.776845   66.416145   22.85551
  17.459694   15.266729   18.263693   16.733091   15.68768    16.69443
  17.767635   16.892221   16.277817   16.699389   15.796449   15.776845
  15.68768    17.459694   18.93738    16.401972   41.471767   15.796449
  82.16476    16.664154   15.68768    17.409088   17.106882   15.488239
  15.266729   17.917303   26.629543   21.383934   18.279554   15.929144
  16.90719    38.06461    16.673416   15.487675   16.253452   15.776845
  16.306961   41.61213    15.9031725  15.488239   19.056816   30.152964
  18.068584   15.796449   15.773586   16.216064   17.080605   16.798647
  16.733091   16.307858   38.06461    15.487675   16.69443    66.416145
  16.733091   39.110504   15.266729   15.796449   16.401972   18.379236
  15.929144   16.276924   15.929144   16.306961  360.43808    15.929144
  20.011204   15.776845   15.796449   15.487675   39.291206   15.68768
  15.488239   30.152964   16.879503   15.929144   16.306961   16.69443
  16.401972   16.553946   15.796449   16.673416   17.080605  379.3062
  16.745249   17.896193   16.90719    15.266729   15.776845   15.776845 ], 
  dtype=float32)

log_density

The log_density function first uses the effect handler substitute to return a callable which substitutes all primitive calls in the model with values from data whose keys match the site names. If a site's name is not present in data, there is no side effect. After substitute, another effect handler, trace, is used to record the inputs, distributions, and outputs of the numpyro.sample statements in the model (and, more generally, of NumPyro primitive calls).
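
These two handlers can also be used on their own. A minimal sketch (with the same hypothetical all-zero parameters as before) that substitutes values into the model and records its trace:

Code
from numpyro.handlers import substitute, trace

params = {"b1": jnp.zeros(X_train.shape[1]), "b2": jnp.zeros(X_train.shape[1])}
substituted_model = substitute(model, data=params)
model_trace = trace(substituted_model).get_trace(X_train, y_train)

for name, site in model_trace.items():
    if site["type"] == "sample":
        # shape () for the priors (due to .to_event(1)); shape (N,) for the plated likelihood
        print(name, site["fn"].log_prob(site["value"]).shape)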

Code
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    
    model = substitute(model, data=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    # ... the loop over `model_trace` is shown in the full listing further below

The effect handlers allow us to loop through each site in the model trace to compute the log joint probability density. In the for loop, if the site's type is "sample", we grab that site's value(s) (the samples from the numpyro.sample statement) and evaluate the log probability of those value(s) under that site's fn (dist.Normal(), dist.MultivariateNormal(), etc.) with site['fn'].log_prob(<some value>). The output snippet below shows the site b1 defined in the model and the value sampled via the numpyro.sample statement.

site: {'type': 'sample', 'name': 'b1', 'fn': <numpyro.distributions.distribution.Independent object at 0x13a9ceb20>, 'args': (), 'kwargs': {'rng_key': None, 'sample_shape': ()}, 'value': Traced<ConcreteArray([ 1.2406311  -0.5222316  -1.2795658   1.800642   -0.43796206], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([ 1.2406311 , -0.5222316 , -1.2795658 ,  1.800642  , -0.43796206],      dtype=float32)
  tangent = Traced<ShapedArray(float32[5])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[5]), None)
    recipe = LambdaBinding(), 'scale': None, 'is_observed': False, 'intermediates': [], 'cond_indep_stack': [], 'infer': {}} 

value: Traced<ConcreteArray([ 1.2406311  -0.5222316  -1.2795658   1.800642   -0.43796206], dtype=float32)>with<JVPTrace(level=2/0)> with
  primal = Array([ 1.2406311 , -0.5222316 , -1.2795658 ,  1.800642  , -0.43796206],      dtype=float32)
  tangent = Traced<ShapedArray(float32[5])>with<JaxprTrace(level=1/0)> with
    pval = (ShapedArray(float32[5]), None)
    recipe = LambdaBinding(), 

We can also inspect the fn of this sample site as defined in our model:

site fn: <numpyro.distributions.distribution.Independent object at 0x13a9ceb20>

where the fn is an Independent Normal distribution because we called the .to_event(1) method in our model; .to_event(1) reinterprets the rightmost batch dimension as an event dimension, so log_prob sums over the five coefficients and returns a scalar. Next, the log probability for the sample site's value is computed by calling the .log_prob() method. For example, the log probability of the sampled values for site b1 is:

log prob.    = Traced<ConcreteArray(-8.036343574523926, dtype=float32)
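
To see why this is a single scalar rather than five values (one per coefficient), here is a quick illustrative check of what .to_event(1) does to log_prob shapes:

Code
import jax.numpy as jnp
import numpyro.distributions as dist

d = dist.Normal(0.0, 1.0).expand([5])
x = jnp.zeros(5)
print(d.log_prob(x).shape)              # (5,): one log probability per coefficient
print(d.to_event(1).log_prob(x).shape)  # (): summed over the event dimension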

In the stock implementation, the site's log probability is then summed with jnp.sum (in the listing further below, this summation is commented out so the per-observation contributions remain visible). Lastly, the running variable log_joint accumulates the current site's log probability. After looping through each sample site, log_joint represents the log joint probability density of the model given the latent values (parameters).

log joint         = Traced<ConcreteArray([ -15.773586   -15.796449   -15.68768    -17.106882   -16.307858
  -15.10714    -16.277817   -16.401972   -17.409088   -15.488239
  -17.44108    -20.011204   -15.796449   -16.401972   -17.409088
  -15.488239   -32.46398    -16.306961   -18.321064   -15.776845
  -16.432753   -41.972443   -16.90719    -17.324219   -15.487675
  -15.68768    -15.10714    -15.68768    -16.96474    -61.152782
  -16.319      -32.46398    -56.631355   -16.733091   -44.390583
  -15.796449   -31.771408   -15.9031725  -17.999393   -15.929144
  -15.796449   -15.776845   -25.234818   -15.487675   -32.46398
  -15.776845   -22.136528   -16.745249   -15.796449   -15.796449
  -61.152782   -17.459694   -15.776845   -39.30664    -31.771408
  -17.106882   -15.796449   -15.776845   -16.836826   -16.90372
  -15.565512   -15.266311   -15.796449   -15.487675   -25.503807
  -66.416145   -42.01054    -15.68768    -16.438005   -35.528217
  -16.401972  -275.8356     -15.488239   -46.813717   -18.31475
  -42.01054    -15.929144   -16.733091   -15.929144   -18.632296
  -16.553946   -22.139755   -16.879503   -16.253452   -15.929144
  -16.68712    -16.90719    -17.409088   -16.306961   -17.900412
  -72.883484   -20.984446   -17.080605   -15.68768    -15.266311
  -17.459694   -15.487675   -17.409088  -221.04407    -15.487675
  -15.68768    -15.68768    -15.903938   -17.608683   -17.233418
  -16.945618   -17.102604   -16.230682   -16.401972   -20.437622
  -16.307858   -15.776845   -66.416145   -22.85551    -17.459694
  -15.266729   -18.263693   -16.733091   -15.68768    -16.69443
  -17.767635   -16.892221   -16.277817   -16.699389   -15.796449
  -15.776845   -15.68768    -17.459694   -18.93738    -16.401972
  -41.471767   -15.796449   -82.16476    -16.664154   -15.68768
  -17.409088   -17.106882   -15.488239   -15.266729   -17.917303
  -26.629543   -21.383934   -18.279554   -15.929144   -16.90719
  -38.06461    -16.673416   -15.487675   -16.253452   -15.776845
  -16.306961   -41.61213    -15.9031725  -15.488239   -19.056816
  -30.152964   -18.068584   -15.796449   -15.773586   -16.216064
  -17.080605   -16.798647   -16.733091   -16.307858   -38.06461
  -15.487675   -16.69443    -66.416145   -16.733091   -39.110504
  -15.266729   -15.796449   -16.401972   -18.379236   -15.929144
  -16.276924   -15.929144   -16.306961  -360.43808    -15.929144
  -20.011204   -15.776845   -15.796449   -15.487675   -39.291206
  -15.68768    -15.488239   -30.152964   -16.879503   -15.929144
  -16.306961   -16.69443    -16.401972   -16.553946   -15.796449
  -16.673416   -17.080605  -379.3062     -16.745249   -17.896193
  -16.90719    -15.266729   -15.776845   -15.776845 ], dtype=float32)

And voilà, this output is the log joint probability density for the model given the latent values (placing a negative sign in front of it gives the potential energy). The values above represent the log joint probability at the initialized latent values, i.e., no inference has been run yet. The full, instrumented log_density function is listed below.

Code
def log_density(model, model_args, model_kwargs, params):
    """
    (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given
    latent values ``params``.

    :param model: Python callable containing NumPyro primitives.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param dict params: dictionary of current parameter values keyed by site
        name.
    :return: log of joint density and a corresponding model trace
    """
    
    model = substitute(model, data=params)
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
    log_joint = jnp.zeros(())
    print("---- inside log_density -----")
    print('\n')
    for site in model_trace.values():
        print(f"site: {site}", '\n')
        if site["type"] == "sample":
            value = site["value"]
            print(f"value: {value}, \n")
            intermediates = site["intermediates"]
            print(f"intermediates: {intermediates}, \n")
            scale = site["scale"]
            print(f"site fn: {site['fn']}, \n")
            if intermediates:
                log_prob = site["fn"].log_prob(value, intermediates)
            else:
                guide_shape = jnp.shape(value)
                model_shape = tuple(
                    site["fn"].shape()
                )  # TensorShape from tfp needs casting to tuple
                try:
                    broadcast_shapes(guide_shape, model_shape)
                except ValueError:
                    raise ValueError(
                        "Model and guide shapes disagree at site: '{}': {} vs {}".format(
                            site["name"], model_shape, guide_shape
                        )
                    )
                log_prob = site["fn"].log_prob(value)

            if (scale is not None) and (not is_identically_one(scale)):
                log_prob = scale * log_prob

            # print(f"before sum log prob.    = {log_prob}, \n")
            # log_prob = jnp.sum(log_prob)
            # print(f"after sum log prob.     = {log_prob}, \n")
            log_joint = log_joint + log_prob
            print(f"log joint               = {log_joint}, \n")
    return log_joint, model_trace
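
As a sanity check, we can recompute the log joint by hand and compare it against the stock numpyro.infer.util.log_density (which, unlike the instrumented version above, sums each site's log probability to a scalar). The all-zero parameters are again hypothetical:

Code
from jax.scipy.special import expit
from numpyro.infer.util import log_density

params = {"b1": jnp.zeros(X_train.shape[1]), "b2": jnp.zeros(X_train.shape[1])}
log_joint, _ = log_density(model, (X_train, y_train), {}, params)

prior = dist.Normal(0.0, 1.0).expand([X_train.shape[1]]).to_event(1)
q = expit(X_train @ params["b1"])
lam = jnp.exp(X_train @ params["b2"])
by_hand = (
    prior.log_prob(params["b1"])
    + prior.log_prob(params["b2"])
    + dist.ZeroInflatedPoisson(gate=q, rate=lam).log_prob(y_train).sum()
)
print(jnp.allclose(log_joint, by_hand))  # True (up to floating point error)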

Returning to get_potential_fn

However, log_density and potential_energy compute a value, the log probability at the current latent values, not a function. This is where get_potential_fn earns its name: using functools.partial, it binds the model and the model args (here, X_train and y_train) to potential_energy, so that the resulting potential_fn is a function of the latent values alone. This potential_fn is then “fed” into the HMC algorithm to perform inference.
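
The pattern is easy to replicate ourselves. A minimal sketch of the same functools.partial move made in the non-dynamic branch of get_potential_fn, binding the model and data so that only the latent values remain free:

Code
from functools import partial
from numpyro.infer.util import potential_energy

potential_fn = partial(potential_energy, model, (X_train, y_train), {})
params = {"b1": jnp.zeros(X_train.shape[1]), "b2": jnp.zeros(X_train.shape[1])}
print(potential_fn(params))  # same value as calling potential_energy directly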

What’s great about this design is that the log joint probability density function can be accessed externally given our NumPyro model and passed into other sampling libraries such as Blackjax.
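
A minimal sketch of extracting the potential function and turning it into the log density callable that external samplers expect (Blackjax itself is not imported here, as its exact API varies by version):

Code
import jax
from jax.random import PRNGKey
from numpyro.infer.util import initialize_model

model_info = initialize_model(PRNGKey(0), model, model_args=(X_train, y_train))
potential_fn = model_info.potential_fn

def logdensity_fn(params):
    # samplers such as Blackjax take a log density, i.e. the negative potential energy
    return -potential_fn(params)

z = model_info.param_info.z        # initial unconstrained values
print(logdensity_fn(z))            # scalar log joint density
print(jax.grad(logdensity_fn)(z))  # gradients w.r.t. "b1" and "b2"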

Summary

To translate a NumPyro model, in the form of a Python function, into a joint probability (both a function and a value), a series of function calls is performed. Namely, initialize model \(\leftrightarrow\) get potential fn \(\leftrightarrow\) potential energy \(\leftrightarrow\) log density.

  • log_density computes the log joint probability of the current latent (parameter) values and the observed data
  • potential_energy negates it: the potential energy is -log_joint
  • get_potential_fn returns the log joint probability as a function, potential_fn, which evaluates the potential energy at given latent values, with the model args (here, X_train and y_train) already bound; this potential_fn is then “fed” into the HMC algorithm to perform inference

This process allows users to (1) translate their models in a dynamic manner, and (2) access the log joint probability for external use, such as in Blackjax.