pyro.contrib.funsor, Pyro 的一个新后端 - 新原语 (第一部分)

简介

在本教程中,我们将介绍 pyro.contrib.funsor 的基础知识,这是 Pyro 概率编程系统的一个新后端,旨在取代 Pyro 当前的内部机制,并显著扩展其作为建模工具和推断研究平台的能力。

本教程面向那些有兴趣开发自定义推断算法并理解 Pyro 当前和未来的内部机制的读者。因此,本文假设您对通用 Pyro API 包 pyroapi 以及 Funsor 有一定的了解。Funsor 的额外文档可以在Pyro 网站GitHub 上以及研究论文“Functional Tensors for Probabilistic Programming.” 中找到。对这些细节不太感兴趣的读者会发现他们已经可以通过 pyroapi 使用 contrib.funsor 中的通用算法来处理他们现有的 Pyro 模型。

使用 pyroapi 重新解释现有的 Pyro 模型

新后端使用 pyroapi 包与现有 Pyro 代码集成。

首先,我们导入一些依赖项

[1]:
from collections import OrderedDict

import torch
import funsor
from pyro import set_rng_seed as pyro_set_rng_seed

funsor.set_backend("torch")
torch.set_default_dtype(torch.float32)
pyro_set_rng_seed(101)

导入 pyro.contrib.funsor 会将 "contrib.funsor" 后端注册到 pyroapi,现在可以将其作为参数传递给 pyroapi.pyro_backend 上下文管理器。

[2]:
import pyro.contrib.funsor
import pyroapi
from pyroapi import handlers, infer, ops, optim, pyro
from pyroapi import distributions as dist

# this is already done in pyro.contrib.funsor, but we repeat it here
pyroapi.register_backend("contrib.funsor", dict(
    distributions="pyro.distributions",
    handlers="pyro.contrib.funsor.handlers",
    infer="pyro.contrib.funsor.infer",
    ops="torch",
    optim="pyro.optim",
    pyro="pyro.contrib.funsor",
))

就这样开始了!从现在起,任何 pyro.(...) 语句应被理解为分派给新后端。

两个新原语:to_funsorto_data

pyro.contrib.funsor 中第一个也是最重要的概念是新的一对原语 pyro.to_funsorpyro.to_data

这些是 effectful 版本的 funsor.to_funsorfunsor.to_data,即其行为可以被 Pyro 的代数效应处理器库截获、控制或用来触发副作用的版本。在深入研究 pyro.contrib.funsor 中有副作用的版本之前,我们先简要回顾一下这两个底层函数。

正如其名称所示,to_funsor 接受非 funsor.Funsor 对象作为输入,并尝试将它们转换为 Funsor 项。例如,在一个 Python 数字上调用 funsor.to_funsor 会将其转换为一个 funsor.terms.Number 对象

[3]:
funsor_one = funsor.to_funsor(float(1))
print(funsor_one, type(funsor_one))

funsor_two = funsor.to_funsor(torch.tensor(2.))
print(funsor_two, type(funsor_two))
1.0 <class 'funsor.terms.Number'>
tensor(2.) <class 'funsor.tensor.Tensor'>

类似地,在一个原子 funsor.Funsor 上调用 funsor.to_data 会将其转换为一个常规 Python 对象,例如 float 或一个 torch.Tensor

[4]:
data_one = funsor.to_data(funsor.terms.Number(float(1), 'real'))
print(data_one, type(data_one))

data_two = funsor.to_data(funsor.Tensor(torch.tensor(2.), OrderedDict(), 'real'))
print(data_two, type(data_two))
1.0 <class 'float'>
tensor(2.) <class 'torch.Tensor'>

在许多情况下,需要提供输出类型才能唯一地将数据转换为 funsor.Funsor。这也意味着,严格来说,funsor.to_funsorfunsor.to_data 不是互逆的。例如,funsor.to_funsor 会自动将 Python 字符串转换为 funsor.Variable,但只有在提供了输出 funsor.domains.Domain 时才会这样做,它用作变量的类型

[5]:
var_x = funsor.to_funsor("x", output=funsor.Reals[2])
print(var_x, var_x.inputs, var_x.output)
x OrderedDict([('x', Reals[2])]) Reals[2]

然而,在没有额外输入类型信息的情况下,通常无法唯一地将对象转换为 Funsor 表达式或从 Funsor 表达式转换回对象,如下面的 torch.Tensor 示例所示,它可以被转换为 funsor.Tensor 通过多种方式。

为了解决这种歧义,我们需要向 to_funsorto_data 提供类型信息,描述如何转换位置维度与无序的命名 Funsor 维度之间。这些信息采用将批处理维度映射到维度名称或反之的字典形式。

这些映射的一个关键属性是它们遵循以下约定:维度索引指的是 batch dimensions (批处理维度),或者不包含在 output shape (输出形状) 中的维度,它被视为指代底层 PyTorch 张量形状的最右侧部分,如下面的示例所示。

[6]:
ambiguous_tensor = torch.zeros((3, 1, 2))
print("Ambiguous tensor: shape = {}".format(ambiguous_tensor.shape))

# case 1: treat all dimensions as output/event dimensions
funsor1 = funsor.to_funsor(ambiguous_tensor, output=funsor.Reals[3, 1, 2])
print("Case 1: inputs = {}, output = {}".format(funsor1.inputs, funsor1.output))

# case 2: treat the leftmost dimension as a batch dimension
# note that dimension -1 in dim_to_name here refers to the rightmost *batch dimension*,
# i.e. dimension -3 of ambiguous_tensor, the rightmost dimension not included in the output shape.
funsor2 = funsor.to_funsor(ambiguous_tensor, output=funsor.Reals[1, 2], dim_to_name={-1: "a"})
print("Case 2: inputs = {}, output = {}".format(funsor2.inputs, funsor2.output))

# case 3: treat the leftmost 2 dimensions as batch dimensions; empty batch dimensions are ignored
# note that dimensions -1 and -2 in dim_to_name here refer to the rightmost *batch dimensions*,
# i.e. dimensions -2 and -3 of ambiguous_tensor, the rightmost dimensions not included in the output shape.
funsor3 = funsor.to_funsor(ambiguous_tensor, output=funsor.Reals[2], dim_to_name={-1: "b", -2: "a"})
print("Case 3: inputs = {}, output = {}".format(funsor3.inputs, funsor3.output))

# case 4: treat all dimensions as batch dimensions; empty batch dimensions are ignored
# note that dimensions -1, -2 and -3 in dim_to_name here refer to the rightmost *batch dimensions*,
# i.e. dimensions -1, -2 and -3 of ambiguous_tensor, the rightmost dimensions not included in the output shape.
funsor4 = funsor.to_funsor(ambiguous_tensor, output=funsor.Real, dim_to_name={-1: "c", -2: "b", -3: "a"})
print("Case 4: inputs = {}, output = {}".format(funsor4.inputs, funsor4.output))
Ambiguous tensor: shape = torch.Size([3, 1, 2])
Case 1: inputs = OrderedDict(), output = Reals[3,1,2]
Case 2: inputs = OrderedDict([('a', Bint[3, ])]), output = Reals[1,2]
Case 3: inputs = OrderedDict([('a', Bint[3, ])]), output = Reals[2]
Case 4: inputs = OrderedDict([('a', Bint[3, ]), ('c', Bint[2, ])]), output = Real

to_data 也存在类似的歧义:一个 funsor.Funsorinputs 是任意顺序的,数据中的空维度会被压缩掉,因此必须提供从名称到批处理维度的映射以确保唯一转换

[7]:
ambiguous_funsor = funsor.Tensor(torch.zeros((3, 2)), OrderedDict(a=funsor.Bint[3], b=funsor.Bint[2]), 'real')
print("Ambiguous funsor: inputs = {}, shape = {}".format(ambiguous_funsor.inputs, ambiguous_funsor.output))

# case 1: the simplest version
tensor1 = funsor.to_data(ambiguous_funsor, name_to_dim={"a": -2, "b": -1})
print("Case 1: shape = {}".format(tensor1.shape))

# case 2: an empty dimension between a and b
tensor2 = funsor.to_data(ambiguous_funsor, name_to_dim={"a": -3, "b": -1})
print("Case 2: shape = {}".format(tensor2.shape))

# case 3: permuting the input dimensions
tensor3 = funsor.to_data(ambiguous_funsor, name_to_dim={"a": -1, "b": -2})
print("Case 3: shape = {}".format(tensor3.shape))
Ambiguous funsor: inputs = OrderedDict([('a', Bint[3, ]), ('b', Bint[2, ])]), shape = Real
Case 1: shape = torch.Size([3, 2])
Case 2: shape = torch.Size([3, 1, 2])
Case 3: shape = torch.Size([2, 3])

高效地维护和更新这些信息随着转换次数的增加变得繁琐且容易出错。幸运的是,它可以被完全自动化。考虑下面的示例

[8]:
name_to_dim = OrderedDict()

funsor_x = funsor.Tensor(torch.ones((2,)), OrderedDict(x=funsor.Bint[2]), 'real')
name_to_dim.update({"x": -1})
tensor_x = funsor.to_data(funsor_x, name_to_dim=name_to_dim)
print(name_to_dim, funsor_x.inputs, tensor_x.shape)

funsor_y = funsor.Tensor(torch.ones((3, 2)), OrderedDict(y=funsor.Bint[3], x=funsor.Bint[2]), 'real')
name_to_dim.update({"y": -2})
tensor_y = funsor.to_data(funsor_y, name_to_dim=name_to_dim)
print(name_to_dim, funsor_y.inputs, tensor_y.shape)

funsor_z = funsor.Tensor(torch.ones((2, 3)), OrderedDict(z=funsor.Bint[2], y=funsor.Bint[3]), 'real')
name_to_dim.update({"z": -3})
tensor_z = funsor.to_data(funsor_z, name_to_dim=name_to_dim)
print(name_to_dim, funsor_z.inputs, tensor_z.shape)
OrderedDict([('x', -1)]) OrderedDict([('x', Bint[2, ])]) torch.Size([2])
OrderedDict([('x', -1), ('y', -2)]) OrderedDict([('y', Bint[3, ]), ('x', Bint[2, ])]) torch.Size([3, 2])
OrderedDict([('x', -1), ('y', -2), ('z', -3)]) OrderedDict([('z', Bint[2, ]), ('y', Bint[3, ])]) torch.Size([2, 3, 1])

这正是由 pyro.to_funsorpyro.to_data 提供的功能,通过在前面的示例中使用它们并移除手动更新,我们可以看到这一点。我们还必须用 handlers.named 效应处理器包装函数,以确保维度字典不会在函数体之外持续存在。

[9]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    funsor_x = funsor.Tensor(torch.ones((2,)), OrderedDict(x=funsor.Bint[2]), 'real')
    tensor_x = pyro.to_data(funsor_x)
    print(funsor_x.inputs, tensor_x.shape)

    funsor_y = funsor.Tensor(torch.ones((3, 2)), OrderedDict(y=funsor.Bint[3], x=funsor.Bint[2]), 'real')
    tensor_y = pyro.to_data(funsor_y)
    print(funsor_y.inputs, tensor_y.shape)

    funsor_z = funsor.Tensor(torch.ones((2, 3)), OrderedDict(z=funsor.Bint[2], y=funsor.Bint[3]), 'real')
    tensor_z = pyro.to_data(funsor_z)
    print(funsor_z.inputs, tensor_z.shape)
OrderedDict([('x', Bint[2, ])]) torch.Size([2, 1, 1, 1, 1])
OrderedDict([('y', Bint[3, ]), ('x', Bint[2, ])]) torch.Size([3, 2, 1, 1, 1, 1])
OrderedDict([('z', Bint[2, ]), ('y', Bint[3, ])]) torch.Size([2, 3, 1, 1, 1, 1, 1])

关键的是,pyro.to_funsorpyro.to_data 使用并更新相同的名称和维度之间的双向映射,这使得它们可以直观地组合使用。一个典型的使用模式,也是 pyro.contrib.funsor 在其推断算法实现中大量使用的模式,是直接创建一个带有新命名维度的 funsor.Funsor 项,并在其上调用 pyro.to_data,执行一些 PyTorch 计算,然后在结果上调用 pyro.to_funsor

[10]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():

    probs = funsor.Tensor(torch.tensor([0.5, 0.4, 0.7]), OrderedDict(batch=funsor.Bint[3]))
    print(type(probs), probs.inputs, probs.output)

    x = funsor.Tensor(torch.tensor([0., 1., 0., 1.]), OrderedDict(x=funsor.Bint[4]))
    print(type(x), x.inputs, x.output)

    dx = dist.Bernoulli(pyro.to_data(probs))
    print(type(dx), dx.shape())

    px = pyro.to_funsor(dx.log_prob(pyro.to_data(x)), output=funsor.Real)
    print(type(px), px.inputs, px.output)
<class 'funsor.tensor.Tensor'> OrderedDict([('batch', Bint[3, ])]) Real
<class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[4, ])]) Real
<class 'pyro.distributions.torch.Bernoulli'> torch.Size([3, 1, 1, 1, 1])
<class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[4, ]), ('batch', Bint[3, ])]) Real

pyro.to_funsorpyro.to_data 将它们名称到维度映射中的键视为对输入批处理形状的引用,但将值视为对全局一致的名称-维度映射的引用。这对于涉及 PyTorch 和 Funsor 操作混合的复杂计算可能很有用。

[11]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():

    x = pyro.to_funsor(torch.tensor([0., 1.]), funsor.Real, dim_to_name={-1: "x"})
    print("x: ", type(x), x.inputs, x.output)

    px = pyro.to_funsor(torch.ones(2, 3), funsor.Real, dim_to_name={-2: "x", -1: "y"})
    print("px: ", type(px), px.inputs, px.output)
x:  <class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[2, ])]) Real
px:  <class 'funsor.tensor.Tensor'> OrderedDict([('x', Bint[2, ]), ('y', Bint[3, ])]) Real

处理大量变量:(重新)引入 pyro.markov

到目前为止,一切顺利。然而,如果不同命名维度的数量持续增加怎么办?我们面临两个问题:首先,重复使用固定数量的可用位置维度(在 PyTorch 中是 25 个),其次,计算与变量数量无关的时间复杂度的形状信息。

针对此问题的完全通用的自动化解决方案需要与 Python 或 PyTorch 进行更深度的集成。相反,作为一种中间解决方案,我们引入了 pyro.contrib.funsor 中的第二个关键概念:pyro.markov 注解,这是一种表示某些变量生命周期的方式。pyro.markov 已经是 Pyro 的一部分(参见枚举教程),但 pyro.contrib.funsor 中的实现是全新的。

pyro.markov 设计的主要限制是向后兼容性:为了使 pyro.contrib.funsor 与大量的现有 Pyro 模型兼容,新实现必须尽可能地匹配 Pyro 现有枚举机制的形状语义。

[12]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    for i in pyro.markov(range(10)):
        x = pyro.to_data(funsor.Tensor(torch.tensor([0., 1.]), OrderedDict({"x{}".format(i): funsor.Bint[2]})))
        print("Shape of x[{}]: ".format(str(i)), x.shape)
Shape of x[0]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[1]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[2]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[3]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[4]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[5]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[6]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[7]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[8]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[9]:  torch.Size([2, 1, 1, 1, 1, 1])

pyro.markov 是一种多功能的语法元素,可以用作上下文管理器、装饰器或迭代器。重要的是要理解 pyro.markov 目前唯一的功能是跟踪变量使用情况,而不是直接向推断算法指示条件独立性属性,因此只需要添加足够的注解来确保张量具有正确的形状,而不是尝试手动编码尽可能多的依赖信息。

pyro.markov 接受一个附加参数 history,该参数确定在给定的 pyro.to_funsor/pyro.to_data 调用中构建名称和维度之间的映射时要考虑多少个先前的 pyro.markov 上下文。

[13]:
with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    for i in pyro.markov(range(10), history=2):
        x = pyro.to_data(funsor.Tensor(torch.tensor([0., 1.]), OrderedDict({"x{}".format(i): funsor.Bint[2]})))
        print("Shape of x[{}]: ".format(str(i)), x.shape)
Shape of x[0]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[1]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[2]:  torch.Size([2, 1, 1, 1, 1, 1, 1])
Shape of x[3]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[4]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[5]:  torch.Size([2, 1, 1, 1, 1, 1, 1])
Shape of x[6]:  torch.Size([2, 1, 1, 1, 1])
Shape of x[7]:  torch.Size([2, 1, 1, 1, 1, 1])
Shape of x[8]:  torch.Size([2, 1, 1, 1, 1, 1, 1])
Shape of x[9]:  torch.Size([2, 1, 1, 1, 1])

超出枚举范围的使用场景:全局维度和可见维度

全局维度

有时,让维度和变量忽略 pyro.markov 程序的结构并在任意深度嵌套的 markovnamed 上下文中保持活跃,这会很有用。例如,假设我们想从 Pyro 模型的联合分布中抽取一批样本。为此,我们通过 dim_type 关键字参数向 pyro.to_data 表明某个维度应被视为“全局”(DimType.GLOBAL)。

[14]:
from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimType

with pyroapi.pyro_backend("contrib.funsor"), handlers.named():
    funsor_particle_ids = funsor.Tensor(torch.arange(10), OrderedDict(n=funsor.Bint[10]))
    tensor_particle_ids = pyro.to_data(funsor_particle_ids, dim_type=DimType.GLOBAL)
    print("New global dimension: ", funsor_particle_ids.inputs, tensor_particle_ids.shape)
New global dimension:  OrderedDict([('n', Bint[10, ])]) torch.Size([10, 1, 1, 1, 1])

pyro.markov 负责自动管理局部维度这项繁重的工作,但由于全局维度忽略了这种结构,它们必须手动释放,否则它们会一直持续到最后一个活跃的效应处理器退出,就像 Python 中的全局变量会持续到程序执行结束一样。

[15]:
from pyro.contrib.funsor.handlers.runtime import _DIM_STACK, DimType

with pyroapi.pyro_backend("contrib.funsor"), handlers.named():

    funsor_plate1_ids = funsor.Tensor(torch.arange(10), OrderedDict(plate1=funsor.Bint[10]))
    tensor_plate1_ids = pyro.to_data(funsor_plate1_ids, dim_type=DimType.GLOBAL)
    print("New global dimension: ", funsor_plate1_ids.inputs, tensor_plate1_ids.shape)

    funsor_plate2_ids = funsor.Tensor(torch.arange(9), OrderedDict(plate2=funsor.Bint[9]))
    tensor_plate2_ids = pyro.to_data(funsor_plate2_ids, dim_type=DimType.GLOBAL)
    print("Another new global dimension: ", funsor_plate2_ids.inputs, tensor_plate2_ids.shape)

    del _DIM_STACK.global_frame["plate1"]

    funsor_plate3_ids = funsor.Tensor(torch.arange(10), OrderedDict(plate3=funsor.Bint[10]))
    tensor_plate3_ids = pyro.to_data(funsor_plate1_ids, dim_type=DimType.GLOBAL)
    print("A third new global dimension after recycling: ", funsor_plate3_ids.inputs, tensor_plate3_ids.shape)
New global dimension:  OrderedDict([('plate1', Bint[10, ])]) torch.Size([10, 1, 1, 1, 1])
Another new global dimension:  OrderedDict([('plate2', Bint[9, ])]) torch.Size([9, 1, 1, 1, 1, 1])
A third new global dimension after recycling:  OrderedDict([('plate3', Bint[10, ])]) torch.Size([10, 1, 1, 1, 1])

直接执行这种释放通常是不必要的,我们包含这种交互主要是为了阐明 pyro.contrib.funsor 的内部机制。相反,引入全局维度的效应处理器,例如 pyro.plate,可以继承自 GlobalNamedMessenger 效应处理器,该处理器在进入和退出时通用地释放全局维度。我们将在下一个教程中看到这方面的示例。

可见维度

我们也可能希望保留数据张量形状的含义。为此,我们向 pyro.to_data 表明某个维度不仅应被视为全局的,还应被视为“可见的”(DimTypes.VISIBLE)。默认情况下,最右边的 4 个批处理维度被保留为“可见”维度,但这可以通过设置全局状态对象 _DIM_STACKfirst_available_dim 属性来改变。

接触过 pyro.infer.TraceEnum_ELBOmax_plate_nesting 参数的用户已经熟悉这种区别了。

[16]:
prev_first_available_dim = _DIM_STACK.set_first_available_dim(-2)

with pyroapi.pyro_backend("contrib.funsor"), handlers.named():

    funsor_local_ids = funsor.Tensor(torch.arange(9), OrderedDict(k=funsor.Bint[9]))
    tensor_local_ids = pyro.to_data(funsor_local_ids, dim_type=DimType.LOCAL)
    print("Tensor with new local dimension: ", funsor_local_ids.inputs, tensor_local_ids.shape)

    funsor_global_ids = funsor.Tensor(torch.arange(10), OrderedDict(n=funsor.Bint[10]))
    tensor_global_ids = pyro.to_data(funsor_global_ids, dim_type=DimType.GLOBAL)
    print("Tensor with new global dimension: ", funsor_global_ids.inputs, tensor_global_ids.shape)

    funsor_data_ids = funsor.Tensor(torch.arange(11), OrderedDict(m=funsor.Bint[11]))
    tensor_data_ids = pyro.to_data(funsor_data_ids, dim_type=DimType.VISIBLE)
    print("Tensor with new visible dimension: ", funsor_data_ids.inputs, tensor_data_ids.shape)

# we also need to reset the first_available_dim after we're done
_DIM_STACK.set_first_available_dim(prev_first_available_dim)
Tensor with new local dimension:  OrderedDict([('k', Bint[9, ])]) torch.Size([9, 1])
Tensor with new global dimension:  OrderedDict([('n', Bint[10, ])]) torch.Size([10, 1, 1])
Tensor with new visible dimension:  OrderedDict([('m', Bint[11, ])]) torch.Size([11])
[16]:
-5

可见维度也是全局的,因此必须手动释放,否则它们会持续到最后一个效应处理器退出,如同前面的示例一样。您现在可能在想,Funsor 的维度名称的行为有点像 Python 变量,具有跨表达式的作用域和持久含义;确实如此,这一观察结果是 Funsor 设计背后的关键洞察。

幸运的是,直接与维度分配器交互几乎总是不必要的,正如前一节所述,我们在此包含它仅为了阐明 pyro.contrib.funsor 的内部工作原理;相反,像 pyro.handlers.enum 这样的效应处理器,它们可能引入与可见维度冲突的非可见维度,应该继承自基础 pyro.contrib.funsor.handlers.named_messenger.NamedMessenger 效应处理器。

然而,对维度分配器的内部工作原理建立一些直觉,将使您更容易使用 contrib.funsor 中的新原语来构建强大的新的自定义推断引擎。我们将在下一个教程中看到一个这样的推断引擎示例。