Pyro 中的张量形状

本教程介绍 Pyro 中张量维度的组织方式。开始之前,您应熟悉 PyTorch 广播语义
学完本教程后,您可能还需要阅读关于 枚举 的内容。

您可能还会发现阅读 Eric J. Ma 的文章 Reasoning about Shapes and Probability Distributions 会有帮助。尽管这篇文章专门针对 TensorFlow Probability,但许多概念是通用的。

总结:

  • 张量通过右对齐进行广播:torch.ones(3,4,5) + torch.ones(5)

  • 分布的 .sample().shape == batch_shape + event_shape

  • 分布的 .log_prob(x).shape == batch_shape (但不包括 event_shape!)。

  • 使用 .expand() 绘制一批样本,或者依赖 plate 自动扩展。

  • 使用 my_dist.to_event(1) 将维度声明为依赖维度。

  • 使用 with pyro.plate('name', size): 将维度声明为条件独立维度。

  • 所有维度都必须被声明为依赖维度或条件独立维度。

  • 尽量支持左侧的批处理。这使得 Pyro 可以自动并行化。

    • 使用负索引,例如 x.sum(-1),而不是 x.sum(2)

    • 使用省略号表示法,例如 pixel = image[..., i, j]

    • 如果 i,j 是枚举的,使用 Vindex,例如 pixel = Vindex(image)[..., i, j]

  • 使用 pyro.plate 的自动子采样时,请确保对您的数据进行子采样

    • 或者通过捕获索引手动进行子采样:with pyro.plate(...) as i: ...

    • 或者通过 batch = pyro.subsample(data, event_dim=...) 自动进行子采样。

  • 调试时,使用 Trace.format_shapes() 检查跟踪中的所有形状。

目录

[1]:
import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.1')

# We'll ue this helper to check our models are correct.
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)

分布形状:batch_shapeevent_shape

PyTorch 的 Tensor 只有一个 .shape 属性,但 Distribution 有两个具有特殊含义的形状属性:.batch_shape.event_shape。这两者结合起来定义了样本的总形状

x = d.sample()
assert x.shape == d.batch_shape + d.event_shape

.batch_shape 上的索引表示条件独立随机变量,而 .event_shape 上的索引表示依赖随机变量(即从分布中抽取一次)。由于依赖随机变量共同定义了概率,.log_prob() 方法对于每个形状为 .event_shape 的事件只产生一个数字。因此,.log_prob() 的总形状是 .batch_shape

assert d.log_prob(x).shape == d.batch_shape

请注意,Distribution.sample() 方法还接受一个 sample_shape 参数,该参数对独立同分布 (iid) 随机变量进行索引,因此

x2 = d.sample(sample_shape)
assert x2.shape == sample_shape + batch_shape + event_shape

总结来说

      |      iid     | independent | dependent
------+--------------+-------------+------------
shape = sample_shape + batch_shape + event_shape

例如,单变量分布的事件形状为空(因为每个数字都是一个独立事件)。像 MultivariateNormal 这样的向量分布的 len(event_shape) == 1。像 InverseWishart 这样的矩阵分布的 len(event_shape) == 2

示例

最简单的分布形状是单个单变量分布。

[2]:
d = Bernoulli(0.5)
assert d.batch_shape == ()
assert d.event_shape == ()
x = d.sample()
assert x.shape == ()
assert d.log_prob(x).shape == ()

可以通过传入批处理参数来对分布进行批处理。

[3]:
d = Bernoulli(0.5 * torch.ones(3,4))
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

另一种对分布进行批处理的方法是使用 .expand() 方法。这仅在参数沿最左侧维度相同时有效。

[4]:
d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand([3, 4])
assert d.batch_shape == (3, 4)
assert d.event_shape == ()
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3, 4)

多变量分布的 .event_shape 非空。对于这些分布,.sample().log_prob(x) 的形状不同

[5]:
d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3))
assert d.batch_shape == ()
assert d.event_shape == (3,)
x = d.sample()
assert x.shape == (3,)            # == batch_shape + event_shape
assert d.log_prob(x).shape == ()  # == batch_shape

重塑分布

在 Pyro 中,您可以通过调用 .to_event(n) 属性将单变量分布视为多变量分布,其中 n 是从右侧开始声明为 依赖 的批量维度数。

[6]:
d = Bernoulli(0.5 * torch.ones(3,4)).to_event(1)
assert d.batch_shape == (3,)
assert d.event_shape == (4,)
x = d.sample()
assert x.shape == (3, 4)
assert d.log_prob(x).shape == (3,)

在编写 Pyro 程序时,请记住样本的形状是 batch_shape + event_shape,而 .log_prob(x) 值的形状是 batch_shape。您需要通过使用 .to_event(n) 裁剪 batch_shape 或通过 pyro.plate 将维度声明为独立来仔细控制 batch_shape

假设存在依赖总是安全的

在 Pyro 中,我们经常将某些维度声明为依赖维度,即使它们实际上是独立的,例如:

x = pyro.sample("x", Normal(0, 1).expand([10]).to_event(1))
assert x.shape == (10,)

这有两个原因:首先,它使我们稍后可以轻松地换入 MultivariateNormal 分布。其次,它稍微简化了代码,因为我们不需要像在下方那样使用 plate

with pyro.plate("x_plate", 10):
    x = pyro.sample("x", Normal(0, 1))  # .expand([10]) is automatic
    assert x.shape == (10,)

这两个版本的区别在于,带 plate 的第二个版本告知 Pyro 在估计梯度时可以利用条件独立性信息,而在第一个版本中 Pyro 必须假定它们是依赖的(尽管这些正态分布实际上是条件独立的)。这类似于图模型中的 d 分离:增加边并假设变量 可能 依赖(即放宽模型类别)总是安全的,但在变量实际依赖时假定独立是不安全的(即收紧模型类别,使得真实模型落在此类别之外,如在均值场中)。实际上,Pyro 的 SVI 推断算法对 Normal 分布使用重参数化梯度估计器,因此两种梯度估计器具有相同的性能。

使用 plate 声明独立维度

Pyro 模型可以使用上下文管理器 pyro.plate 来声明某些批量维度是独立的。推断算法可以利用这种独立性,例如构建方差更低的梯度估计器,或者在线性空间而非指数空间中进行枚举。独立维度的一个例子是小批量数据上的索引:每个数据点应独立于所有其他数据点。

声明维度独立的最简单方法是通过一个简单的语句将最右边的批量维度声明为独立:

with pyro.plate("my_plate"):
    # within this context, batch dimension -1 is independent

我们建议始终提供可选的 size 参数,以帮助调试形状

with pyro.plate("my_plate", len(my_data)):
    # within this context, batch dimension -1 is independent

从 Pyro 0.2 开始,您还可以嵌套 plates,例如,如果您有像素级别的独立性

with pyro.plate("x_axis", 320):
    # within this context, batch dimension -1 is independent
    with pyro.plate("y_axis", 200):
        # within this context, batch dimensions -2 and -1 are independent

请注意,我们总是从右侧开始计数,使用负索引,如 -2, -1。

最后,如果您想混合使用 plate 来处理例如仅依赖于 x 的噪声、仅依赖于 y 的噪声以及同时依赖于两者的噪声,您可以声明多个 plate 并将它们用作可重用的上下文管理器。在这种情况下,Pyro 无法自动分配维度,因此您需要提供一个 dim 参数(同样从右侧开始计数)

x_axis = pyro.plate("x_axis", 3, dim=-2)
y_axis = pyro.plate("y_axis", 2, dim=-3)
with x_axis:
    # within this context, batch dimension -2 is independent
with y_axis:
    # within this context, batch dimension -3 is independent
with x_axis, y_axis:
    # within this context, batch dimensions -3 and -2 are independent

让我们仔细看看 plate 中的批量大小。

[7]:
def model1():
    a = pyro.sample("a", Normal(0, 1))
    b = pyro.sample("b", Normal(torch.zeros(2), 1).to_event(1))
    with pyro.plate("c_plate", 2):
        c = pyro.sample("c", Normal(torch.zeros(2), 1))
    with pyro.plate("d_plate", 3):
        d = pyro.sample("d", Normal(torch.zeros(3,4,5), 1).to_event(2))
    assert a.shape == ()       # batch_shape == ()     event_shape == ()
    assert b.shape == (2,)     # batch_shape == ()     event_shape == (2,)
    assert c.shape == (2,)     # batch_shape == (2,)   event_shape == ()
    assert d.shape == (3,4,5)  # batch_shape == (3,)   event_shape == (4,5)

    x_axis = pyro.plate("x_axis", 3, dim=-2)
    y_axis = pyro.plate("y_axis", 2, dim=-3)
    with x_axis:
        x = pyro.sample("x", Normal(0, 1))
    with y_axis:
        y = pyro.sample("y", Normal(0, 1))
    with x_axis, y_axis:
        xy = pyro.sample("xy", Normal(0, 1))
        z = pyro.sample("z", Normal(0, 1).expand([5]).to_event(1))
    assert x.shape == (3, 1)        # batch_shape == (3,1)     event_shape == ()
    assert y.shape == (2, 1, 1)     # batch_shape == (2,1,1)   event_shape == ()
    assert xy.shape == (2, 3, 1)    # batch_shape == (2,3,1)   event_shape == ()
    assert z.shape == (2, 3, 1, 5)  # batch_shape == (2,3,1)   event_shape == (5,)

test_model(model1, model1, Trace_ELBO())

通过将每个采样站点的 .shapebatch_shapeevent_shape 的边界处对齐来可视化形状是很有帮助的:右侧的维度将在 .log_prob() 中求和消除,左侧的维度将保留。

batch dims | event dims
-----------+-----------
           |        a = sample("a", Normal(0, 1))
           |2       b = sample("b", Normal(zeros(2), 1)
           |                        .to_event(1))
           |        with plate("c", 2):
          2|            c = sample("c", Normal(zeros(2), 1))
           |        with plate("d", 3):
          3|4 5         d = sample("d", Normal(zeros(3,4,5), 1)
           |                       .to_event(2))
           |
           |        x_axis = plate("x", 3, dim=-2)
           |        y_axis = plate("y", 2, dim=-3)
           |        with x_axis:
        3 1|            x = sample("x", Normal(0, 1))
           |        with y_axis:
      2 1 1|            y = sample("y", Normal(0, 1))
           |        with x_axis, y_axis:
      2 3 1|            xy = sample("xy", Normal(0, 1))
      2 3 1|5           z = sample("z", Normal(0, 1).expand([5])
           |                       .to_event(1))

要自动检查程序中采样站点的形状,您可以跟踪程序并使用 Trace.format_shapes() 方法,该方法会为每个采样站点打印三种形状:分布形状(包括 site["fn"].batch_shapesite["fn"].event_shape)、值形状(site["value"].shape),如果已计算对数概率,还会打印 log_prob 形状(site["log_prob"].shape

[8]:
trace = poutine.trace(model1).get_trace()
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
Trace Shapes:
 Param Sites:
Sample Sites:
       a dist       |
        value       |
     log_prob       |
       b dist       | 2
        value       | 2
     log_prob       |
 c_plate dist       |
        value     2 |
     log_prob       |
       c dist     2 |
        value     2 |
     log_prob     2 |
 d_plate dist       |
        value     3 |
     log_prob       |
       d dist     3 | 4 5
        value     3 | 4 5
     log_prob     3 |
  x_axis dist       |
        value     3 |
     log_prob       |
  y_axis dist       |
        value     2 |
     log_prob       |
       x dist   3 1 |
        value   3 1 |
     log_prob   3 1 |
       y dist 2 1 1 |
        value 2 1 1 |
     log_prob 2 1 1 |
      xy dist 2 3 1 |
        value 2 3 1 |
     log_prob 2 3 1 |
       z dist 2 3 1 | 5
        value 2 3 1 | 5
     log_prob 2 3 1 |

plate 内部对张量进行子采样

使用 plate 的主要用途之一是对数据进行子采样。这在 plate 内部是可能的,因为数据是条件独立的,因此,例如,一半数据的损失的期望值应该是全部数据损失期望值的一半。

要对数据进行子采样,您需要告知 Pyro 原始数据大小和子采样大小;然后 Pyro 将选择数据的随机子集并生成索引集合。

[9]:
data = torch.arange(100.)

def model2():
    mean = pyro.param("mean", torch.zeros(len(data)))
    with pyro.plate("data", len(data), subsample_size=10) as ind:
        assert len(ind) == 10    # ind is a LongTensor that indexes the subsample.
        batch = data[ind]        # Select a minibatch of data.
        mean_batch = mean[ind]   # Take care to select the relevant per-datum parameters.
        # Do stuff with batch:
        x = pyro.sample("x", Normal(mean_batch, 1), obs=batch)
        assert len(x) == 10

test_model(model2, guide=lambda: None, loss=Trace_ELBO())

使用广播实现并行枚举

Pyro 0.2 引入了并行枚举离散潜变量的能力。这可以在通过 SVI 学习后验时显著降低梯度估计器的方差。

要使用并行枚举,Pyro 需要分配可用于枚举的张量维度。为了避免与我们想要用于 plate 的其他维度冲突,我们需要声明我们将使用的最大张量维度数量的预算。此预算称为 max_plate_nesting,并且是 SVI 的一个参数(该参数只是简单地传递给 TraceEnum_ELBO)。通常 Pyro 可以自行确定此预算(它会运行 (model,guide) 对一次并记录发生的情况),但在模型结构动态变化的情况下,您可能需要手动声明 max_plate_nesting

为了理解 max_plate_nesting 以及 Pyro 如何为枚举分配维度,让我们回顾上面提到的 model1()。这次我们将绘制出三种类型的维度:左侧的枚举维度(Pyro 控制这些维度)、中间的批量维度和右侧的事件维度。

      max_plate_nesting = 3
           |<--->|
enumeration|batch|event
-----------+-----+-----
           |. . .|      a = sample("a", Normal(0, 1))
           |. . .|2     b = sample("b", Normal(zeros(2), 1)
           |     |                      .to_event(1))
           |     |      with plate("c", 2):
           |. . 2|          c = sample("c", Normal(zeros(2), 1))
           |     |      with plate("d", 3):
           |. . 3|4 5       d = sample("d", Normal(zeros(3,4,5), 1)
           |     |                     .to_event(2))
           |     |
           |     |      x_axis = plate("x", 3, dim=-2)
           |     |      y_axis = plate("y", 2, dim=-3)
           |     |      with x_axis:
           |. 3 1|          x = sample("x", Normal(0, 1))
           |     |      with y_axis:
           |2 1 1|          y = sample("y", Normal(0, 1))
           |     |      with x_axis, y_axis:
           |2 3 1|          xy = sample("xy", Normal(0, 1))
           |2 3 1|5         z = sample("z", Normal(0, 1).expand([5]))
           |     |                     .to_event(1))

请注意,超量配置 max_plate_nesting=4 是安全的,但我们不能欠量配置 max_plate_nesting=2(否则 Pyro 会出错)。让我们看看这在实践中是如何工作的。

[10]:
@config_enumerate
def model3():
    p = pyro.param("p", torch.arange(6.) / 6)
    locs = pyro.param("locs", torch.tensor([-1., 1.]))

    a = pyro.sample("a", Categorical(torch.ones(6) / 6))
    b = pyro.sample("b", Bernoulli(p[a]))  # Note this depends on a.
    with pyro.plate("c_plate", 4):
        c = pyro.sample("c", Bernoulli(0.3))
        with pyro.plate("d_plate", 5):
            d = pyro.sample("d", Bernoulli(0.4))
            e_loc = locs[d.long()].unsqueeze(-1)
            e_scale = torch.arange(1., 8.)
            e = pyro.sample("e", Normal(e_loc, e_scale)
                            .to_event(1))  # Note this depends on d.

    #                   enumerated|batch|event dims
    assert a.shape == (         6, 1, 1   )  # Six enumerated values of the Categorical.
    assert b.shape == (      2, 1, 1, 1   )  # Two enumerated Bernoullis, unexpanded.
    assert c.shape == (   2, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.
    assert d.shape == (2, 1, 1, 1, 1, 1   )  # Only two Bernoullis, unexpanded.
    assert e.shape == (2, 1, 1, 1, 5, 4, 7)  # This is sampled and depends on d.

    assert e_loc.shape   == (2, 1, 1, 1, 1, 1, 1,)
    assert e_scale.shape == (                  7,)

test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))

让我们仔细看看这些维度。首先请注意,Pyro 在 max_plate_nesting 处从右侧开始分配枚举维度:Pyro 分配维度 -3 来枚举 a,然后分配维度 -4 来枚举 b,然后分配维度 -5 来枚举 c,最后分配维度 -6 来枚举 d。接下来请注意,样本仅在新枚举维度中具有范围(大小 > 1)。这有助于保持张量较小且计算成本较低。(请注意,log_prob 形状将被广播,使其包含枚举形状和批量形状,因此例如 trace.nodes['d']['log_prob'].shape == (2, 1, 1, 1, 5, 4)。)

我们可以绘制类似的张量维度图

     max_plate_nesting = 2
            |<->|
enumeration batch event
------------|---|-----
           6|1 1|     a = pyro.sample("a", Categorical(torch.ones(6) / 6))
         2 1|1 1|     b = pyro.sample("b", Bernoulli(p[a]))
            |   |     with pyro.plate("c_plate", 4):
       2 1 1|1 1|         c = pyro.sample("c", Bernoulli(0.3))
            |   |         with pyro.plate("d_plate", 5):
     2 1 1 1|1 1|             d = pyro.sample("d", Bernoulli(0.4))
     2 1 1 1|1 1|1            e_loc = locs[d.long()].unsqueeze(-1)
            |   |7            e_scale = torch.arange(1., 8.)
     2 1 1 1|5 4|7            e = pyro.sample("e", Normal(e_loc, e_scale)
            |   |                             .to_event(1))

要使用枚举语义自动检查此模型,我们可以创建一个枚举跟踪,然后使用 Trace.format_shapes()

[11]:
trace = poutine.trace(poutine.enum(model3, first_available_dim=-3)).get_trace()
trace.compute_log_prob()  # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
Trace Shapes:
 Param Sites:
            p             6
         locs             2
Sample Sites:
       a dist             |
        value       6 1 1 |
     log_prob       6 1 1 |
       b dist       6 1 1 |
        value     2 1 1 1 |
     log_prob     2 6 1 1 |
 c_plate dist             |
        value           4 |
     log_prob             |
       c dist           4 |
        value   2 1 1 1 1 |
     log_prob   2 1 1 1 4 |
 d_plate dist             |
        value           5 |
     log_prob             |
       d dist         5 4 |
        value 2 1 1 1 1 1 |
     log_prob 2 1 1 1 5 4 |
       e dist 2 1 1 1 5 4 | 7
        value 2 1 1 1 5 4 | 7
     log_prob 2 1 1 1 5 4 |

编写可并行化的代码

编写能正确处理并行化采样站点的 Pyro 模型可能很棘手。有两个技巧可以帮助:广播省略号切片。让我们看一个虚构的模型,了解这些在实践中是如何工作的。我们的目标是编写一个无论是否进行枚举都能工作的模型。

[12]:
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumerated = None  # set to either True or False below

def fun(observe):
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.plate('x_axis', width, dim=-2)
    y_axis = pyro.plate('y_axis', height, dim=-1)

    # Note that the shapes of these sites depend on whether Pyro is enumerating.
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y))
    if enumerated:
        assert x_active.shape  == (2, 1, 1)
        assert y_active.shape  == (2, 1, 1, 1)
    else:
        assert x_active.shape  == (width, 1)
        assert y_active.shape  == (height,)

    # The first trick is to broadcast. This works with or without enumeration.
    p = 0.1 + 0.5 * x_active * y_active
    if enumerated:
        assert p.shape == (2, 2, 1, 1)
    else:
        assert p.shape == (width, height)
    dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))

    # The second trick is to index using ellipsis slicing.
    # This allows Pyro to add arbitrary dimensions on the left.
    for x, y in sparse_pixels:
        dense_pixels[..., x, y] = 1
    if enumerated:
        assert dense_pixels.shape == (2, 2, width, height)
    else:
        assert dense_pixels.shape == (width, height)

    with x_axis, y_axis:
        if observe:
            pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)

def model4():
    fun(observe=True)

def guide4():
    fun(observe=False)

# Test without enumeration.
enumerated = False
test_model(model4, guide4, Trace_ELBO())

# Test with enumeration.
enumerated = True
test_model(model4, config_enumerate(guide4, "parallel"),
           TraceEnum_ELBO(max_plate_nesting=2))

pyro.plate 内部的自动广播

请注意,在我们所有的模型/guide 规范中,我们都依赖 pyro.plate 自动扩展样本形状,以满足 pyro.sample 语句强加的批量形状约束。然而,这种广播等效于手动注解的 .expand() 语句。

我们将使用 上一节model4 来演示这一点。请注意与之前代码的以下更改

  • 为了本示例的目的,我们只考虑“并行”枚举,但广播在没有枚举或使用“顺序”枚举的情况下应该能按预期工作。

  • 我们已经将返回与活动像素对应的张量的采样函数分离出来。将模型代码模块化为组件是一种常见的做法,有助于维护大型模型。

  • 我们还想使用 pyro.plate 结构来并行化 ELBO 估计器在 num_particles 上的计算。这可以通过将模型/guide 的内容包装在最外层的 pyro.plate 上下文中来实现。

[13]:
num_particles = 100  # Number of samples for the ELBO estimator
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])

def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x).expand([num_particles, width, 1]))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y).expand([num_particles, 1, height]))
    return x_active, y_active

def sample_pixel_locations_full_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y))
    return x_active, y_active

def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):
    with x_axis:
        x_active = pyro.sample("x_active", Bernoulli(p_x).expand([width, 1]))
    with y_axis:
        y_active = pyro.sample("y_active", Bernoulli(p_y).expand([height]))
    return x_active, y_active

def fun(observe, sample_fn):
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.plate('x_axis', width, dim=-2)
    y_axis = pyro.plate('y_axis', height, dim=-1)

    with pyro.plate("num_particles", 100, dim=-3):
        x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)
        # Indices corresponding to "parallel" enumeration are appended
        # to the left of the "num_particles" plate dim.
        assert x_active.shape  == (2, 1, 1, 1)
        assert y_active.shape  == (2, 1, 1, 1, 1)
        p = 0.1 + 0.5 * x_active * y_active
        assert p.shape == (2, 2, 1, 1, 1)

        dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
        for x, y in sparse_pixels:
            dense_pixels[..., x, y] = 1
        assert dense_pixels.shape == (2, 2, 1, width, height)

        with x_axis, y_axis:
            if observe:
                pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)

def test_model_with_sample_fn(sample_fn):
    def model():
        fun(observe=True, sample_fn=sample_fn)

    @config_enumerate
    def guide():
        fun(observe=False, sample_fn=sample_fn)

    test_model(model, guide, TraceEnum_ELBO(max_plate_nesting=3))

test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_full_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting)

在第一个采样函数中,我们必须做一些手动记录工作,并扩展 Bernoulli 分布的批量形状,以考虑由 pyro.plate 上下文添加的条件独立维度。特别是,请注意 sample_pixel_locations 需要知道 num_particleswidthheight,并且正在从全局作用域访问这些变量,这并不理想。

  • pyro.plate 的第二个参数,即可选的 size 参数需要提供给隐式广播,以便它可以推断出每个采样站点的批量形状要求。

  • 采样站点现有的 batch_shape 必须与 pyro.plate 上下文的大小可广播。在我们的特定示例中,Bernoulli(p_x) 具有一个空的批量形状,它是普遍可广播的。

请注意,使用 pyro.plate 通过张量化操作实现并行化是多么简单!pyro.plate 还有助于代码模块化,因为模型组件可以编写成与它们可能随后嵌入的 plate 上下文无关。

[ ]: