pyro.contrib.funsor, a new backend for Pyro - Building inference algorithms (Part 2)

[1]:
from collections import OrderedDict
import functools

import torch
from torch.distributions import constraints

import funsor

from pyro import set_rng_seed as pyro_set_rng_seed
from pyro.ops.indexing import Vindex
from pyro.poutine.messenger import Messenger

funsor.set_backend("torch")
torch.set_default_dtype(torch.float32)
pyro_set_rng_seed(101)

Introduction

In Part 1 of this tutorial, we introduced pyro.contrib.funsor, a new backend for Pyro.

Here we'll look at how to use the components of pyro.contrib.funsor to implement a variable-elimination inference algorithm from scratch. This tutorial assumes readers are familiar with enumeration-based inference algorithms in Pyro; for background and motivation, readers should consult the enumeration tutorial.

As before, we'll use pyroapi so that we can write our model with standard Pyro syntax.

[2]:
import pyro.contrib.funsor
import pyroapi
from pyroapi import infer, handlers, ops, optim, pyro
from pyroapi import distributions as dist

Throughout, we will use the following model: a discrete-state, continuous-observation hidden Markov model with learnable transition and emission distributions that depend on a global random variable.

[3]:
data = [torch.tensor(1.)] * 10

def model(data, verbose):

    p = pyro.param("probs", lambda: torch.rand((3, 3)), constraint=constraints.simplex)
    locs_mean = pyro.param("locs_mean", lambda: torch.ones((3,)))
    locs = pyro.sample("locs", dist.Normal(locs_mean, 1.).to_event(1))
    if verbose:
        print("locs.shape = {}".format(locs.shape))

    x = 0
    for i in pyro.markov(range(len(data))):
        x = pyro.sample("x{}".format(i), dist.Categorical(p[x]), infer={"enumerate": "parallel"})
        if verbose:
            print("x{}.shape = ".format(i), x.shape)
        pyro.sample("y{}".format(i), dist.Normal(Vindex(locs)[..., x], 1.), obs=data[i])

Using pyroapi, we can run model under both the default Pyro backend and the new contrib.funsor backend.

[4]:
# default backend: "pyro"
with pyroapi.pyro_backend("pyro"):
    model(data, verbose=True)

# new backend: "contrib.funsor"
with pyroapi.pyro_backend("contrib.funsor"):
    model(data, verbose=True)
locs.shape = torch.Size([3])
x0.shape =  torch.Size([])
x1.shape =  torch.Size([])
x2.shape =  torch.Size([])
x3.shape =  torch.Size([])
x4.shape =  torch.Size([])
x5.shape =  torch.Size([])
x6.shape =  torch.Size([])
x7.shape =  torch.Size([])
x8.shape =  torch.Size([])
x9.shape =  torch.Size([])
locs.shape = torch.Size([3])
x0.shape =  torch.Size([])
x1.shape =  torch.Size([])
x2.shape =  torch.Size([])
x3.shape =  torch.Size([])
x4.shape =  torch.Size([])
x5.shape =  torch.Size([])
x6.shape =  torch.Size([])
x7.shape =  torch.Size([])
x8.shape =  torch.Size([])
x9.shape =  torch.Size([])

Enumerating discrete variables

Our first step is to implement an effect handler that performs parallel enumeration of discrete latent variables. Here we will implement a simplified version of pyro.poutine.enum, the effect handler behind Pyro's most powerful general-purpose inference algorithms, pyro.infer.TraceEnum_ELBO and pyro.infer.mcmc.HMC.

We will do this by constructing a funsor.Tensor representing the support of each discrete latent variable and converting it to an appropriately shaped torch.Tensor using the new pyro.to_data primitive from Part 1.
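
As a quick illustration of the raw values EnumMessenger starts from, enumerate_support(expand=False) on a torch.distributions distribution returns one value per element of the support, with singleton batch dimensions (a small sketch; the distribution here is just for demonstration):

d = torch.distributions.Categorical(torch.ones(4, 3) / 3.)  # batch_shape (4,), 3 categories
print(d.enumerate_support(expand=False).shape)               # torch.Size([3, 1]): (support_size,) + (1,) * len(batch_shape)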

[5]:
from pyro.contrib.funsor.handlers.named_messenger import NamedMessenger

class EnumMessenger(NamedMessenger):

    @pyroapi.pyro_backend("contrib.funsor")  # necessary since we invoke pyro.to_data and pyro.to_funsor
    def _pyro_sample(self, msg):
        if msg["done"] or msg["is_observed"] or msg["infer"].get("enumerate") != "parallel":
            return

        # We first compute a raw value using the standard enumerate_support method.
        # enumerate_support returns a value of shape:
        #     (support_size,) + (1,) * len(msg["fn"].batch_shape).
        raw_value = msg["fn"].enumerate_support(expand=False)

        # Next we'll use pyro.to_funsor to indicate that this dimension is fresh.
        # This is guaranteed because we use msg['name'], the name of this pyro.sample site,
        # as the name for this positional dimension, and sample site names must be unique.
        funsor_value = pyro.to_funsor(
            raw_value,
            output=funsor.Bint[raw_value.shape[0]],
            dim_to_name={-raw_value.dim(): msg["name"]},
        )

        # Finally, we convert the value back to a PyTorch tensor with to_data,
        # which has the effect of reshaping and possibly permuting dimensions of raw_value.
        # Applying to_funsor and to_data in this way guarantees that
        # each enumerated random variable gets a unique fresh positional dimension
        # and that we can convert the model's log-probability tensors to funsor.Tensors
        # in a globally consistent manner.
        msg["value"] = pyro.to_data(funsor_value)
        msg["done"] = True

Because this is an introductory tutorial, this implementation of EnumMessenger works directly with the site's PyTorch distribution, which users familiar with PyTorch and Pyro may find easier to follow. However, when using contrib.funsor to implement an inference algorithm in a more realistic setting, it is usually preferable to do as much computation as possible on funsors, since this tends to simplify complicated indexing, broadcasting, and shape-manipulation logic.

For example, in EnumMessenger we might instead call pyro.to_funsor on msg["fn"]:

funsor_dist = pyro.to_funsor(msg["fn"], output=funsor.Real)(value=msg["name"])
# enumerate_support defined whenever isinstance(funsor_dist, funsor.distribution.Distribution)
funsor_value = funsor_dist.enumerate_support(expand=False)
raw_value = pyro.to_data(funsor_value)

Most of the more complete inference algorithms implemented in pyro.contrib.funsor follow this pattern, and we will see an example later in this tutorial. Before we continue, let's see what effect EnumMessenger has on the shapes of random variables in our model.

[6]:
with pyroapi.pyro_backend("contrib.funsor"), \
        EnumMessenger():
    model(data, True)
locs.shape = torch.Size([3])
x0.shape =  torch.Size([3, 1, 1, 1, 1])
x1.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x2.shape =  torch.Size([3, 1, 1, 1, 1])
x3.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x4.shape =  torch.Size([3, 1, 1, 1, 1])
x5.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x6.shape =  torch.Size([3, 1, 1, 1, 1])
x7.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x8.shape =  torch.Size([3, 1, 1, 1, 1])
x9.shape =  torch.Size([3, 1, 1, 1, 1, 1])

Vectorizing a model over multiple samples

Next, since the prior over our global variable is continuous and cannot be enumerated exactly, we will implement an effect handler that uses a global dimension to draw multiple samples from the model in parallel. Our implementation will allocate a new particle dimension using pyro.to_data, just like EnumMessenger above, but unlike the enumeration dimensions, we want the particle dimension to be shared across all sample sites, so we mark it as a DimType.GLOBAL dimension when calling pyro.to_funsor.

Recall from Part 1 that DimType.GLOBAL dimensions must be deallocated manually, or else they persist until the final effect handler exits. This low-level detail is handled automatically by the GlobalNamedMessenger handler provided in pyro.contrib.funsor, which is the base class for any effect handler that allocates global dimensions. Our vectorization effect handler will inherit from this class.

[7]:
from pyro.contrib.funsor.handlers.named_messenger import GlobalNamedMessenger
from pyro.contrib.funsor.handlers.runtime import DimRequest, DimType

class VectorizeMessenger(GlobalNamedMessenger):

    def __init__(self, size, name="_PARTICLES"):
        super().__init__()
        self.name = name
        self.size = size

    @pyroapi.pyro_backend("contrib.funsor")
    def _pyro_sample(self, msg):
        if msg["is_observed"] or msg["done"] or msg["infer"].get("enumerate") == "parallel":
            return

        # we'll first draw a raw batch of samples similarly to EnumMessenger.
        # However, since we are drawing a single batch from the joint distribution,
        # we don't need to take multiple samples if the site is already batched.
        if self.name in pyro.to_funsor(msg["fn"], funsor.Real).inputs:
            raw_value = msg["fn"].rsample()
        else:
            raw_value = msg["fn"].rsample(sample_shape=(self.size,))

        # As before, we'll use pyro.to_funsor to register the new dimension.
        # This time, we indicate that the particle dimension should be treated as a global dimension.
        fresh_dim = len(msg["fn"].event_shape) - raw_value.dim()
        funsor_value = pyro.to_funsor(
            raw_value,
            output=funsor.Reals[tuple(msg["fn"].event_shape)],
            dim_to_name={fresh_dim: DimRequest(value=self.name, dim_type=DimType.GLOBAL)},
        )

        # finally, convert the sample to a PyTorch tensor using to_data as before
        msg["value"] = pyro.to_data(funsor_value)
        msg["done"] = True

Let's see what effect VectorizeMessenger has on the shapes of values in model.

[8]:
with pyroapi.pyro_backend("contrib.funsor"), \
        VectorizeMessenger(size=10):
    model(data, verbose=True)
locs.shape = torch.Size([10, 1, 1, 1, 1, 3])
x0.shape =  torch.Size([])
x1.shape =  torch.Size([])
x2.shape =  torch.Size([])
x3.shape =  torch.Size([])
x4.shape =  torch.Size([])
x5.shape =  torch.Size([])
x6.shape =  torch.Size([])
x7.shape =  torch.Size([])
x8.shape =  torch.Size([])
x9.shape =  torch.Size([])

And now in combination with EnumMessenger:

[9]:
with pyroapi.pyro_backend("contrib.funsor"), \
        VectorizeMessenger(size=10), EnumMessenger():
    model(data, verbose=True)
locs.shape = torch.Size([10, 1, 1, 1, 1, 3])
x0.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x1.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])
x2.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x3.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])
x4.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x5.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])
x6.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x7.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])
x8.shape =  torch.Size([3, 1, 1, 1, 1, 1])
x9.shape =  torch.Size([3, 1, 1, 1, 1, 1, 1])

Computing an ELBO with variable elimination

Now that we have tools for enumerating discrete variables and drawing batches of samples, we can use them to compute quantities of interest for inference algorithms.

Most inference algorithms in Pyro work with a pyro.poutine.Trace, a custom data structure that contains the parameters, sample-site distributions and values, and all of the metadata needed for inference computations. Our third effect handler, LogJointMessenger, departs from this design pattern and thereby removes a great deal of boilerplate: it automatically builds up a lazy Funsor expression for the log of the model's joint probability density, whereas with a Trace this process must be triggered manually by calling Trace.compute_log_prob() and eagerly computing the objective from the individual log-probability tensors in the trace.

In our implementation of LogJointMessenger, unlike in the previous two effect handlers, we will call pyro.to_funsor on both the sample value and the distribution, to show how nearly all inference operations, including log-density evaluation, can be performed directly on funsor.Funsors.
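
To get a feel for the lazy-expression workflow that LogJointMessenger relies on, here is a standalone sketch (toy tensors, not part of the model) that builds a lazy log-joint over two discrete variables and then evaluates it with funsor's optimizer, exactly the pattern used in the loss function below:

from collections import OrderedDict

toy_log_pa = funsor.Tensor(torch.randn(2), OrderedDict(a=funsor.Bint[2]))                        # log p(a)
toy_log_pba = funsor.Tensor(torch.randn(2, 3), OrderedDict(a=funsor.Bint[2], b=funsor.Bint[3]))  # log p(b | a)

with funsor.interpreter.interpretation(funsor.terms.lazy):
    toy_log_joint = toy_log_pa + toy_log_pba                                        # lazy sum of log-densities
    toy_log_z = toy_log_joint.reduce(funsor.ops.logaddexp, frozenset({"a", "b"}))   # lazy marginalization

# apply_optimizer chooses an efficient contraction order, then evaluates eagerly
print(funsor.optimizer.apply_optimizer(toy_log_z))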

[10]:
class LogJointMessenger(Messenger):

    def __enter__(self):
        self.log_joint = funsor.Number(0.)
        return super().__enter__()

    @pyroapi.pyro_backend("contrib.funsor")
    def _pyro_post_sample(self, msg):

        # for Monte Carlo-sampled variables, we don't include a log-density term:
        if not msg["is_observed"] and not msg["infer"].get("enumerate"):
            return

        with funsor.interpreter.interpretation(funsor.terms.lazy):
            funsor_dist = pyro.to_funsor(msg["fn"], output=funsor.Real)
            funsor_value = pyro.to_funsor(msg["value"], output=funsor_dist.inputs["value"])
            self.log_joint += funsor_dist(value=funsor_value)

And finally the actual loss function, which applies our three effect handlers to compute a log-density expression, marginalizes over the discrete variables with funsor.ops.logaddexp, averages over the Monte Carlo samples with funsor.ops.add, and evaluates the final lazy expression with Funsor's optimize interpreter to perform variable elimination.

Note that log_z collapses the model's local discrete latent variables exactly, but is an ELBO with respect to any continuous latent variables, and is thus equivalent to a simple version of TraceEnum_ELBO with no guide.
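
Concretely, writing y_{0:9} for the observations, x_{0:9} for the enumerated discrete states, and locs for the global variable sampled from its prior by VectorizeMessenger, the target quantity is (a sketch of what is being estimated, not a line-by-line transcription of the code):

$$\mathrm{log\_z} \;=\; \mathbb{E}_{\mathrm{locs}\,\sim\, p(\mathrm{locs})}\Big[\log \sum_{x_{0:9}} p(y_{0:9}, x_{0:9} \mid \mathrm{locs})\Big] \;\le\; \log p(y_{0:9}),$$

where the inner sum is performed exactly by the logaddexp reduction and the outer expectation is approximated with size Monte Carlo samples; by Jensen's inequality it lower-bounds the true log marginal likelihood.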

[11]:
@pyroapi.pyro_backend("contrib.funsor")
def log_z(model, model_args, size=10):
    with LogJointMessenger() as tr, \
            VectorizeMessenger(size=size) as v, \
            EnumMessenger():
        model(*model_args)

    with funsor.interpreter.interpretation(funsor.terms.lazy):
        prod_vars = frozenset({v.name})
        sum_vars = frozenset(tr.log_joint.inputs) - prod_vars

        # sum over the discrete random variables we enumerated
        expr = tr.log_joint.reduce(funsor.ops.logaddexp, sum_vars)

        # average over the sample dimension
        expr = expr.reduce(funsor.ops.add, prod_vars) - funsor.Number(float(size))

    return pyro.to_data(funsor.optimizer.apply_optimizer(expr))

Putting it all together

Finally, with all of this machinery implemented, we can compute stochastic gradients with respect to the ELBO.

[12]:
with pyroapi.pyro_backend("contrib.funsor"):
    model(data, verbose=False)  # initialize parameters
    params = [pyro.param("probs").unconstrained(), pyro.param("locs_mean").unconstrained()]

optimizer = torch.optim.Adam(params, lr=0.1)
for step in range(5):
    optimizer.zero_grad()
    log_marginal = log_z(model, (data, False))
    (-log_marginal).backward()
    optimizer.step()
    print(log_marginal)
tensor(-133.6274, grad_fn=<AddBackward0>)
tensor(-129.2379, grad_fn=<AddBackward0>)
tensor(-125.9609, grad_fn=<AddBackward0>)
tensor(-123.7484, grad_fn=<AddBackward0>)
tensor(-122.3034, grad_fn=<AddBackward0>)