离散隐变量推断

本教程介绍了 Pyro 用于离散隐变量模型的枚举策略。本教程假设读者已熟悉 张量形状指南

摘要

  • Pyro 实现了离散隐变量的自动枚举。

  • 该策略可以单独使用,也可以在 SVI(通过 TraceEnum_ELBO)、HMC 或 NUTS 内部使用。

  • 独立的 infer_discrete 可以生成样本或 MAP 估计。

  • 使用注解 infer={"enumerate": "parallel"} 标记采样点以触发枚举。

  • 如果采样点决定了下游结构,则改用 {"enumerate": "sequential"}

  • 编写模型时,允许左侧任意深度的批处理,例如使用广播 (broadcasting)。

  • 推断成本与树宽呈指数关系,因此请尝试编写树宽较小的模型。

  • 如果遇到问题,请在 forum.pyro.ai 上寻求帮助!

目录

[1]:
import os
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.infer.autoguide import AutoNormal
from pyro.ops.indexing import Vindex

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.1')
pyro.set_rng_seed(0)

概述

Pyro 的枚举策略 (Obermeyer 等人,2019) 涵盖了流行的算法,包括变量消除 (variable elimination)、精确消息传递 (exact message passing)、前向-滤波-后向-采样 (forward-filter-backward-sample)、inside-out、Baum-Welch 以及许多其他特殊情况算法。除了枚举,Pyro 还实现了许多推断策略,包括变分推断 (SVI) 和蒙特卡洛方法 (HMCNUTS)。枚举可以作为独立的策略通过 infer_discrete 使用,也可以作为其他策略的组成部分。因此,枚举使得 Pyro 可以在 HMC 和 SVI 模型中边缘化离散隐变量,并在 SVI guides 中使用离散变量的变分枚举。

枚举机制

枚举的核心思想是将离散的 pyro.sample 语句解释为完全枚举而不是随机采样。然后其他推断算法可以对枚举值求和。例如,在标准的“采样”解释下,采样语句可能返回一个标量形状的张量(我们将使用简单的模型和 guide 进行说明)

[2]:
def model():
    z = pyro.sample("z", dist.Categorical(torch.ones(5)))
    print(f"model z = {z}")

def guide():
    z = pyro.sample("z", dist.Categorical(torch.ones(5)))
    print(f"guide z = {z}")

elbo = Trace_ELBO()
elbo.loss(model, guide);
guide z = 4
model z = 4

然而,在枚举解释下,同一个采样点将根据其分布的 .enumerate_support() 方法返回一组完全枚举的值。

[3]:
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, config_enumerate(guide, "parallel"));
guide z = tensor([0, 1, 2, 3, 4])
model z = tensor([0, 1, 2, 3, 4])

请注意,我们使用了“并行”枚举,沿新的张量维度进行枚举。这种方法成本低廉,并允许 Pyro 并行计算,但要求下游程序结构避免基于 z 的值进行分支。为了支持动态程序结构,您可以改用“顺序”枚举,它会针对每个样本值运行整个模型-guide 对一次,但这需要多次运行模型。

[4]:
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, config_enumerate(guide, "sequential"));
guide z = 4
model z = 4
guide z = 3
model z = 3
guide z = 2
model z = 2
guide z = 1
model z = 1
guide z = 0
model z = 0

并行枚举比顺序枚举成本低廉但更复杂,因此本教程的其余部分将重点介绍并行变体。请注意,这两种形式可以交错使用。

多个隐变量

我们刚刚看到,单个离散采样点可以通过非标准解释进行枚举。具有单个离散隐变量的模型是混合模型。具有多个离散隐变量的模型可能更复杂,包括 HMM、CRF、DBN 和其他结构化模型。在具有多个离散隐变量的模型中,Pyro 在不同的张量维度中枚举每个变量(从右侧开始计数;参见 张量形状指南)。这使得 Pyro 能够确定变量之间的依赖图,然后使用变量消除算法执行廉价的精确推断。

为了理解枚举维度分配,考虑以下模型,在此我们将变量从模型中折叠出去,而不是在 guide 中枚举它们。

[5]:
@config_enumerate
def model():
    p = pyro.param("p", torch.randn(3, 3).exp(), constraint=constraints.simplex)
    x = pyro.sample("x", dist.Categorical(p[0]))
    y = pyro.sample("y", dist.Categorical(p[x]))
    z = pyro.sample("z", dist.Categorical(p[y]))
    print(f"  model x.shape = {x.shape}")
    print(f"  model y.shape = {y.shape}")
    print(f"  model z.shape = {z.shape}")
    return x, y, z

def guide():
    pass

pyro.clear_param_store()
print("Sampling:")
model()
print("Enumerated Inference:")
elbo = TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, guide);
Sampling:
  model x.shape = torch.Size([])
  model y.shape = torch.Size([])
  model z.shape = torch.Size([])
Enumerated Inference:
  model x.shape = torch.Size([3])
  model y.shape = torch.Size([3, 1])
  model z.shape = torch.Size([3, 1, 1])

检查离散隐状态

虽然 SVI 中的枚举允许快速学习诸如上面的 p 之类的参数,但它不提供对离散隐变量(如上面的 x,y,z)的预测值的访问。我们可以使用独立的 infer_discrete handler 来访问这些值。在这种情况下,guide 很简单,因此我们可以直接将模型包装在 infer_discrete 中。我们需要传递一个 first_available_dim 参数来告诉 infer_discrete 哪些维度可用于枚举;这与 TraceEnum_ELBOmax_plate_nesting 参数相关,通过:

first_available_dim = -1 - max_plate_nesting
[6]:
serving_model = infer_discrete(model, first_available_dim=-1)
x, y, z = serving_model()  # takes the same args as model(), here no args
print(f"x = {x}")
print(f"y = {y}")
print(f"z = {z}")
  model x.shape = torch.Size([3])
  model y.shape = torch.Size([3, 1])
  model z.shape = torch.Size([3, 1, 1])
  model x.shape = torch.Size([])
  model y.shape = torch.Size([])
  model z.shape = torch.Size([])
x = 2
y = 1
z = 0

请注意,infer_discrete 在底层运行模型两次:首先是前向滤波模式,其中采样点被枚举;然后是回放-后向-采样模式,其中采样点被采样。infer_discrete 也可以通过传递 temperature=0 执行 MAP 推断。请注意,尽管 infer_discrete 生成正确的后验样本,但目前它不生成正确的 logprobs,并且不应在其他基于梯度的推断算法中使用。

使用枚举变量进行索引

使用一个或多个枚举变量对张量元素进行 高级索引 可能很棘手。在 Pyro 模型中尤其如此,因为您的模型的索引操作需要在多种解释下工作:从模型采样(生成数据)和枚举推断过程中。例如,假设一个 plate 随机变量 z 依赖于两个不同的随机变量:

p = pyro.param("p", torch.randn(5, 4, 3, 2).exp(),
               constraint=constraints.simplex)
x = pyro.sample("x", dist.Categorical(torch.ones(4)))
y = pyro.sample("y", dist.Categorical(torch.ones(3)))
with pyro.plate("z_plate", 5):
    p_xy = p[..., x, y, :]  # Not compatible with enumeration!
    z = pyro.sample("z", dist.Categorical(p_xy)

由于高级索引的语义,表达式 p[..., x, y, :] 在没有枚举时可以正常工作,但在 xy 被枚举时是错误的。Pyro 提供了一种正确索引的简单方法,但首先让我们看看如何在不使用 Pyro 的情况下正确使用 PyTorch 的高级索引。

# Compatible with enumeration, but not recommended:
p_xy = p[torch.arange(5, device=p.device).reshape(5, 1),
         x.unsqueeze(-1),
         y.unsqueeze(-1),
         torch.arange(2, device=p.device)]

Pyro 提供了一个辅助工具 Vindex()[],用于使用与枚举兼容的高级索引语义,而不是标准的 PyTorch/NumPy 语义。(请注意 Vindex 名称和语义遵循 Numpy 改进提案 NEP 21)。Vindex()[] 使 .__getitem__() 运算符像其他熟悉的运算符 +* 等一样进行广播。使用 Vindex()[],我们可以像 xy 是数字(即未枚举)时一样编写相同的表达式:

# Recommended syntax compatible with enumeration:
p_xy = Vindex(p)[..., x, y, :]

这是一个完整的例子:

[7]:
@config_enumerate
def model():
    p = pyro.param("p", torch.randn(5, 4, 3, 2).exp(), constraint=constraints.simplex)
    x = pyro.sample("x", dist.Categorical(torch.ones(4)))
    y = pyro.sample("y", dist.Categorical(torch.ones(3)))
    with pyro.plate("z_plate", 5):
        p_xy = Vindex(p)[..., x, y, :]
        z = pyro.sample("z", dist.Categorical(p_xy))
    print(f"     p.shape = {p.shape}")
    print(f"     x.shape = {x.shape}")
    print(f"     y.shape = {y.shape}")
    print(f"  p_xy.shape = {p_xy.shape}")
    print(f"     z.shape = {z.shape}")
    return x, y, z

def guide():
    pass

pyro.clear_param_store()
print("Sampling:")
model()
print("Enumerated Inference:")
elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(model, guide);
Sampling:
     p.shape = torch.Size([5, 4, 3, 2])
     x.shape = torch.Size([])
     y.shape = torch.Size([])
  p_xy.shape = torch.Size([5, 2])
     z.shape = torch.Size([5])
Enumerated Inference:
     p.shape = torch.Size([5, 4, 3, 2])
     x.shape = torch.Size([4, 1])
     y.shape = torch.Size([3, 1, 1])
  p_xy.shape = torch.Size([3, 4, 5, 2])
     z.shape = torch.Size([2, 1, 1, 1])

在 plate 中进行枚举时(如下一节所述),Vindex 也可以与通过 with pyro.plate(...) as i 捕获 plate 索引一起使用,以便对批量维度进行索引。这里有一个例子,由于 Dirichlet 分布,事件维度是非平凡的。

[8]:
@config_enumerate
def model():
    data_plate = pyro.plate("data_plate", 6, dim=-1)
    feature_plate = pyro.plate("feature_plate", 5, dim=-2)
    component_plate = pyro.plate("component_plate", 4, dim=-1)
    with feature_plate:
        with component_plate:
            p = pyro.sample("p", dist.Dirichlet(torch.ones(3)))
    with data_plate:
        c = pyro.sample("c", dist.Categorical(torch.ones(4)))
        with feature_plate as vdx:                # Capture plate index.
            pc = Vindex(p)[vdx[..., None], c, :]  # Reshape it and use in Vindex.
            x = pyro.sample("x", dist.Categorical(pc),
                            obs=torch.zeros(5, 6, dtype=torch.long))
    print(f"    p.shape = {p.shape}")
    print(f"    c.shape = {c.shape}")
    print(f"  vdx.shape = {vdx.shape}")
    print(f"    pc.shape = {pc.shape}")
    print(f"    x.shape = {x.shape}")

def guide():
    feature_plate = pyro.plate("feature_plate", 5, dim=-2)
    component_plate = pyro.plate("component_plate", 4, dim=-1)
    with feature_plate, component_plate:
        pyro.sample("p", dist.Dirichlet(torch.ones(3)))

pyro.clear_param_store()
print("Sampling:")
model()
print("Enumerated Inference:")
elbo = TraceEnum_ELBO(max_plate_nesting=2)
elbo.loss(model, guide);
Sampling:
    p.shape = torch.Size([5, 4, 3])
    c.shape = torch.Size([6])
  vdx.shape = torch.Size([5])
    pc.shape = torch.Size([5, 6, 3])
    x.shape = torch.Size([5, 6])
Enumerated Inference:
    p.shape = torch.Size([5, 4, 3])
    c.shape = torch.Size([4, 1, 1])
  vdx.shape = torch.Size([5])
    pc.shape = torch.Size([4, 5, 1, 3])
    x.shape = torch.Size([5, 6])

Plates 和枚举

Pyro 的 plates 表达了随机变量之间的条件独立性。Pyro 的枚举策略可以利用 plates 将枚举笛卡尔积的高成本(与 plate 大小呈指数关系)降低到在锁定步调中枚举条件独立随机变量的低成本(与 plate 大小呈线性关系)。这对于例如小批量数据尤其重要。

为了说明这一点,考虑一个具有共享方差和不同均值的高斯混合模型。

[9]:
@config_enumerate
def model(data, num_components=3):
    print(f"  Running model with {len(data)} data points")
    p = pyro.sample("p", dist.Dirichlet(0.5 * torch.ones(num_components)))
    scale = pyro.sample("scale", dist.LogNormal(0, num_components))
    with pyro.plate("components", num_components):
        loc = pyro.sample("loc", dist.Normal(0, 10))
    with pyro.plate("data", len(data)):
        x = pyro.sample("x", dist.Categorical(p))
        print("    x.shape = {}".format(x.shape))
        pyro.sample("obs", dist.Normal(loc[x], scale), obs=data)
        print("    dist.Normal(loc[x], scale).batch_shape = {}".format(
            dist.Normal(loc[x], scale).batch_shape))

guide = AutoNormal(poutine.block(model, hide=["x", "data"]))

data = torch.randn(10)

pyro.clear_param_store()
print("Sampling:")
model(data)
print("Enumerated Inference:")
elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(model, guide, data);
Sampling:
  Running model with 10 data points
    x.shape = torch.Size([10])
    dist.Normal(loc[x], scale).batch_shape = torch.Size([10])
Enumerated Inference:
  Running model with 10 data points
    x.shape = torch.Size([10])
    dist.Normal(loc[x], scale).batch_shape = torch.Size([10])
  Running model with 10 data points
    x.shape = torch.Size([3, 1])
    dist.Normal(loc[x], scale).batch_shape = torch.Size([3, 1])

观察到,在推断过程中,模型会运行两次:首先由 AutoNormal 追踪采样点,其次由 elbo 计算损失。在第一次运行中,x 具有每个数据一个样本的标准解释,因此形状为 (10,)。在第二次运行中,枚举可以使用相同的三个值 (3,1) 应用于所有数据点,并依赖广播来处理任何依赖于数据的相关采样或观察点。例如,在 pyro.sample("obs",...) 语句中,分布的形状为 (3,1),数据的形状为 (10,),而广播后的对数概率张量的形状为 (3,10)

关于混合模型中枚举的更深入讨论,请参阅 高斯混合模型指南隐马尔可夫模型示例

Plates 之间的依赖关系

在向量化 plates 中枚举的计算节省带来了模型依赖结构的限制(如 (Obermeyer 等人,2019) 中所述)。这些限制是在通常的条件独立性限制之外的。枚举限制由 TraceEnum_ELBO 检查,如果违反将导致错误(然而,通常的条件独立性限制 Pyro 无法普遍验证)。为了完整起见,我们列出所有三个限制:

限制 1:条件独立性

plate 内的变量不能相互依赖(沿 plate 维度)。这适用于任何变量,无论是否被枚举。这适用于顺序 plates 和向量化 plates。例如,以下模型是无效的:

def invalid_model():
    x = 0
    for i in pyro.plate("invalid", 10):
        x = pyro.sample(f"x_{i}", dist.Normal(x, 1.))

限制 2:无下游耦合

向量化 plate 外部的任何变量都不能依赖于该 plate 内部的枚举变量。这将违反 Pyro 的指数加速假设。例如,以下模型是无效的:

@config_enumerate
def invalid_model(data):
    with pyro.plate("plate", 10):  # <--- invalid vectorized plate
        x = pyro.sample("x", dist.Bernoulli(0.5))
    assert x.shape == (10,)
    pyro.sample("obs", dist.Normal(x.sum(), 1.), data)

为了解决此限制,可以将向量化 plate 转换为顺序 plate:

@config_enumerate
def valid_model(data):
    x = []
    for i in pyro.plate("plate", 10):  # <--- valid sequential plate
        x.append(pyro.sample(f"x_{i}", dist.Bernoulli(0.5)))
    assert len(x) == 10
    pyro.sample("obs", dist.Normal(sum(x), 1.), data)

限制 3:每个 plate 只有一个离开路径

最后一个限制很微妙,但这是实现 Pyro 指数加速所必需的。

对于任何枚举变量 xx 依赖的所有枚举变量集合必须在其向量化 plate 嵌套中呈线性可排序。

此要求仅适用于存在至少两个 plates 且至少有三个变量位于不同 plate 上下文的情况。最简单的反例是玻尔兹曼机 (Boltzmann machine):

@config_enumerate
def invalid_model(data):
    plate_1 = pyro.plate("plate_1", 10, dim=-1)  # vectorized
    plate_2 = pyro.plate("plate_2", 10, dim=-2)  # vectorized
    with plate_1:
        x = pyro.sample("x", dist.Bernoulli(0.5))
    with plate_2:
        y = pyro.sample("y", dist.Bernoulli(0.5))
    with plate_1, plate2:
        z = pyro.sample("z", dist.Bernoulli((1. + x + y) / 4.))
        ...

这里我们看到变量 z 依赖于变量 x(它在 plate_1 中而不在 plate_2 中),并且依赖于变量 y(它在 plate_2 中而不在 plate_1 中)。此模型无效,因为无法对 xy 进行线性排序,使得一个的 plate 嵌套级别小于另一个。

为了解决此限制,可以将其中一个 plate 转换为顺序 plate:

@config_enumerate
def valid_model(data):
    plate_1 = pyro.plate("plate_1", 10, dim=-1)  # vectorized
    plate_2 = pyro.plate("plate_2", 10)          # sequential
    with plate_1:
        x = pyro.sample("x", dist.Bernoulli(0.5))
    for i in plate_2:
        y = pyro.sample(f"y_{i}", dist.Bernoulli(0.5))
        with plate_1:
            z = pyro.sample(f"z_{i}", dist.Bernoulli((1. + x + y) / 4.))
            ...

但请注意,这会增加计算复杂度,计算复杂度可能与顺序 plate 的大小呈指数关系。

时间序列示例

考虑一个具有隐状态 \(x_t\) 和观测值 \(y_t\) 的离散 HMM。假设我们想学习其转移概率和发射概率。

[10]:
data_dim = 4
num_steps = 10
data = dist.Categorical(torch.ones(num_steps, data_dim)).sample()

def hmm_model(data, data_dim, hidden_dim=10):
    print(f"Running for {len(data)} time steps")
    # Sample global matrices wrt a Jeffreys prior.
    with pyro.plate("hidden_state", hidden_dim):
        transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))
        emission = pyro.sample("emission", dist.Dirichlet(0.5 * torch.ones(data_dim)))

    x = 0  # initial state
    for t, y in enumerate(data):
        x = pyro.sample(f"x_{t}", dist.Categorical(transition[x]),
                        infer={"enumerate": "parallel"})
        pyro.sample(f"  y_{t}", dist.Categorical(emission[x]), obs=y)
        print(f"  x_{t}.shape = {x.shape}")

我们可以使用带有 autoguide 的 SVI 来学习全局参数。

[11]:
hmm_guide = AutoNormal(poutine.block(hmm_model, expose=["transition", "emission"]))

pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=1)
elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim);
Running for 10 time steps
  x_0.shape = torch.Size([])
  x_1.shape = torch.Size([])
  x_2.shape = torch.Size([])
  x_3.shape = torch.Size([])
  x_4.shape = torch.Size([])
  x_5.shape = torch.Size([])
  x_6.shape = torch.Size([])
  x_7.shape = torch.Size([])
  x_8.shape = torch.Size([])
  x_9.shape = torch.Size([])
Running for 10 time steps
  x_0.shape = torch.Size([10, 1])
  x_1.shape = torch.Size([10, 1, 1])
  x_2.shape = torch.Size([10, 1, 1, 1])
  x_3.shape = torch.Size([10, 1, 1, 1, 1])
  x_4.shape = torch.Size([10, 1, 1, 1, 1, 1])
  x_5.shape = torch.Size([10, 1, 1, 1, 1, 1, 1])
  x_6.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1])
  x_7.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1])
  x_8.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1, 1])
  x_9.shape = torch.Size([10, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

请注意,这里模型运行了两次:首先由 AutoNormal 在没有枚举的情况下运行,以便 autoguide 可以记录所有采样点;然后由 TraceEnum_ELBO 在启用枚举的情况下第二次运行。我们在第一次运行中看到样本具有标准解释,而在第二次运行中样本具有枚举解释。

有关更复杂的示例,包括小批量处理和多个 plates,请参阅 HMM 教程

如何枚举超过 25 个变量

PyTorch 张量在 CUDA 中的维度限制为 25,在 CPU 中为 64。默认情况下,Pyro 在一个新的维度中枚举每个采样点。如果您需要更多采样点,可以使用 pyro.markov 标记您的模型,告知 Pyro 何时可以安全地回收张量维度。让我们看看它如何应用于上面的 HMM 模型。我们唯一需要做的更改是使用 pyro.markov 标记 for 循环,告知 Pyro 循环中每一步的变量仅依赖于循环外部的变量以及循环中当前步和前一步的变量。

- for t, y in enumerate(data):
+ for t, y in pyro.markov(enumerate(data)):
[12]:
def hmm_model(data, data_dim, hidden_dim=10):
    with pyro.plate("hidden_state", hidden_dim):
        transition = pyro.sample("transition", dist.Dirichlet(0.5 * torch.ones(hidden_dim)))
        emission = pyro.sample("emission", dist.Dirichlet(0.5 * torch.ones(data_dim)))

    x = 0  # initial state
    for t, y in pyro.markov(enumerate(data)):
        x = pyro.sample(f"x_{t}", dist.Categorical(transition[x]),
                        infer={"enumerate": "parallel"})
        pyro.sample(f"y_{t}", dist.Categorical(emission[x]), obs=y)
        print(f"x_{t}.shape = {x.shape}")

# We'll reuse the same guide and elbo.
elbo.loss(hmm_model, hmm_guide, data, data_dim=data_dim);
x_0.shape = torch.Size([10, 1])
x_1.shape = torch.Size([10, 1, 1])
x_2.shape = torch.Size([10, 1])
x_3.shape = torch.Size([10, 1, 1])
x_4.shape = torch.Size([10, 1])
x_5.shape = torch.Size([10, 1, 1])
x_6.shape = torch.Size([10, 1])
x_7.shape = torch.Size([10, 1, 1])
x_8.shape = torch.Size([10, 1])
x_9.shape = torch.Size([10, 1, 1])

请注意,此模型现在仅需要三个张量维度:一个用于 plate,一个用于偶数状态,一个用于奇数状态。有关更复杂的示例,请参阅 HMM 示例 中的动态贝叶斯网络模型。

[ ]: