Writing guides using EasyGuide

This tutorial describes the pyro.contrib.easyguide module. It assumes the reader is already familiar with SVI and tensor shapes.

Summary

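  • For simple black-box guides, try using components in pyro.infer.autoguide.

  • For more complex guides, try using components in pyro.contrib.easyguide.

  • Decorate with @easy_guide(model).

  • Select multiple model sites using group = self.group(match="my_regex").

  • Guide a group of sites by a single distribution using group.sample(...).

  • Inspect the concatenated group shape using group.batch_shape, group.event_shape, etc.

  • Use self.plate(...) instead of pyro.plate(...).

  • To be compatible with subsampling, pass the event_dim argument to pyro.param(...).

  • To MAP estimate a model site "foo", use foo = self.map_estimate("foo").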
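The match argument of self.group(...) is a regular expression over model site names. The sketch below assumes the selection behaves like Python's re.match, which anchors only at the start of the name. The site names and the select_sites helper are hypothetical stand-ins that mirror the latent sites of the model used later; this is an illustration, not EasyGuide's code.

```python
import re

# Hypothetical latent site names mirroring the model below:
# one global site plus seven local state sites.
site_names = ["drift"] + ["state_{}".format(t) for t in range(7)]


def select_sites(pattern, names):
    """Return the names that match `pattern` at the start (re.match semantics)."""
    return [name for name in names if re.match(pattern, name)]


print(select_sites("state_[0-9]*", site_names))

# re.match anchors only at the start, so "state_[0-9]*" also accepts longer
# names such as "state_3_aux"; append "$" to require a full match.
print(select_sites("state_[0-9]*", ["state_3", "state_3_aux"]))
print(select_sites("state_[0-9]+$", ["state_3", "state_3_aux"]))
```

If a model ever contains sites whose names merely start with the same prefix, anchoring the pattern with "$" keeps them out of the group.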


[ ]:
import os
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.contrib.easyguide import easy_guide
from pyro.optim import Adam
from torch.distributions import constraints

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.1')

Modeling time series data

Consider a time-series model with a slowly varying continuous latent state and Bernoulli observations with a logistic link function.

[ ]:
def model(batch, subsample, full_size):
    batch = list(batch)
    num_time_steps = len(batch)
    drift = pyro.sample("drift", dist.LogNormal(-1, 0.5))
    with pyro.plate("data", full_size, subsample=subsample):
        z = 0.
        for t in range(num_time_steps):
            z = pyro.sample("state_{}".format(t),
                            dist.Normal(z, drift))
            batch[t] = pyro.sample("obs_{}".format(t),
                                   dist.Bernoulli(logits=z),
                                   obs=batch[t])
    return torch.stack(batch)

Let's generate some data directly from the model.

[ ]:
full_size = 100
num_time_steps = 7
pyro.set_rng_seed(123456789)
data = model([None] * num_time_steps, torch.arange(full_size), full_size)
assert data.shape == (num_time_steps, full_size)

Writing a guide without EasyGuide

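Consider one possible guide for this model, where we point-estimate the drift parameter using a Delta distribution, and then model the local time series using shared uncertainty but local means, via a LowRankMultivariateNormal distribution. There is a single global sample site, which we can model with a param and a sample statement. Then we create a global pair of uncertainty parameters, cov_diag and cov_factor. Next we create the local loc parameter using pyro.param(..., event_dim=...) and draw from an auxiliary sample site. Finally we unpack that auxiliary site into one Delta site per time step. This pattern of an auxiliary site unpacked into Deltas is quite common.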
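To make the shared-uncertainty, local-mean structure concrete, here is a plain-torch sketch using torch.distributions.LowRankMultivariateNormal (pyro.distributions wraps the same class). It uses the same shapes as the guide: cov_diag and cov_factor are shared by all data points, loc has one row per data point, and the implied covariance is cov_factor @ cov_factor.T + diag(cov_diag). The sizes are the tutorial's defaults; the variable per_step is a hypothetical name for the unpacked per-time-step values.

```python
import torch
from torch.distributions import LowRankMultivariateNormal

torch.manual_seed(0)
num_time_steps, rank, batch_size = 7, 3, 20

# Global uncertainty parameters, shared by every data point.
cov_diag = torch.full((num_time_steps,), 0.01)
cov_factor = torch.randn(num_time_steps, rank) * 0.01
# Local means: one row per data point in the minibatch.
loc = torch.full((batch_size, num_time_steps), 0.5)

d = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
# Every data point shares this low-rank-plus-diagonal covariance.
expected_cov = cov_factor @ cov_factor.T + torch.diag(cov_diag)

# One joint draw per data point plays the role of the auxiliary "states" site.
states = d.sample()
# Unpacking: column t holds state_t for every data point in the minibatch.
per_step = [states[:, t] for t in range(num_time_steps)]
```

The low-rank form lets every time step be correlated with every other one while learning only num_time_steps * (rank + 1) shared numbers instead of a full covariance matrix.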

[ ]:
rank = 3

def guide(batch, subsample, full_size):
    num_time_steps, batch_size = batch.shape

    # MAP estimate the drift.
    drift_loc = pyro.param("drift_loc", lambda: torch.tensor(0.1),
                           constraint=constraints.positive)
    pyro.sample("drift", dist.Delta(drift_loc))

    # Model local states using shared uncertainty + local mean.
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full((num_time_steps,), 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(num_time_steps, rank) * 0.01)
    with pyro.plate("data", full_size, subsample=subsample):
        # Sample local mean.
        loc = pyro.param("state_loc",
                         lambda: torch.full((full_size, num_time_steps), 0.5),
                         event_dim=1)
        states = pyro.sample("states",
                             dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag),
                             infer={"is_auxiliary": True})
        # Unpack the joint states into one sample site per time step.
        for t in range(num_time_steps):
            pyro.sample("state_{}".format(t), dist.Delta(states[:, t]))
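In the guide above, state_loc stores one row per data point, with shape (full_size, num_time_steps), and is declared with event_dim=1 inside the subsampled "data" plate. The plain-torch sketch below shows the slicing we expect this to produce, assuming the plate occupies batch dim -1 so that the indexed dimension is -1 - event_dim. It illustrates the idea and is not Pyro's internal code.

```python
import torch

full_size, num_time_steps = 100, 7
event_dim = 1

# The full parameter: one row of local means per data point
# (filled with distinct values so the slice is easy to check).
state_loc = torch.arange(full_size * num_time_steps, dtype=torch.float).reshape(
    full_size, num_time_steps
)

# A minibatch of data-point indices, as passed to pyro.plate(..., subsample=...).
subsample = torch.arange(20, 40)

# The "data" plate is batch dim -1; the rightmost event_dim dims are event
# dims, so the plate's dimension in the parameter is -1 - event_dim.
plate_dim = -1 - event_dim
local_loc = state_loc.index_select(plate_dim, subsample)
```

Without event_dim=1 the rightmost dimension would be treated as the plate dimension, so the subsample indices would be applied to the time axis instead of the data axis.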

让我们使用 SVITrace_ELBO 进行训练,手动将数据分成小批量。

[ ]:
def train(guide, num_epochs=1 if smoke_test else 101, batch_size=20):
    full_size = data.size(-1)
    pyro.get_param_store().clear()
    pyro.set_rng_seed(123456789)
    svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
    for epoch in range(num_epochs):
        pos = 0
        losses = []
        while pos < full_size:
            subsample = torch.arange(pos, pos + batch_size)
            batch = data[:, pos:pos + batch_size]
            pos += batch_size
            losses.append(svi.step(batch, subsample, full_size=full_size))
        epoch_loss = sum(losses) / len(losses)
        if epoch % 10 == 0:
            print("epoch {} loss = {}".format(epoch, epoch_loss / data.numel()))
[ ]:
train(guide)
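The loop in train() builds each subsample with torch.arange(pos, pos + batch_size). That works here because batch_size=20 divides full_size=100. Otherwise the last subsample would contain indices past the end of the data, while the data slice would be silently truncated. A minimal sketch of a clipped alternative follows. The helper name minibatches and the 50-point toy tensor are hypothetical.

```python
import torch


def minibatches(data, batch_size):
    """Yield (subsample, batch) pairs over the last dim of `data`.

    The final minibatch is clipped, so batch_size need not divide the
    dataset size and the indices always agree with the data slice.
    """
    full_size = data.size(-1)
    for pos in range(0, full_size, batch_size):
        end = min(pos + batch_size, full_size)
        yield torch.arange(pos, end), data[:, pos:end]


# Hypothetical stand-in for `data`: 7 time steps x 50 data points.
toy = torch.zeros(7, 50)
sizes = [sub.numel() for sub, _ in minibatches(toy, 20)]
```

Inside train(), the while loop could be replaced by iterating over minibatches(data, batch_size) and passing each pair to svi.step(batch, subsample, full_size=full_size).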

Using EasyGuide

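Now let's simplify using the @easy_guide decorator. Our modifications are:

  1. Decorate with @easy_guide and add a self argument.

  2. Replace the Delta guide with a simple map_estimate().

  3. Select a group of model sites and read their concatenated event_shape.

  4. Replace the auxiliary site and Delta slices with a single group.sample().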
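The hand-written guide kept drift positive with pyro.param(..., constraint=constraints.positive). Pyro's param store implements such constraints by keeping an unconstrained value and mapping it through torch.distributions.transform_to. We assume map_estimate("drift") similarly keeps its point estimate inside the site's support. The plain-torch sketch below shows that mechanism for the LogNormal prior on drift; it illustrates the constraint only and does not reproduce map_estimate's internals.

```python
import torch
from torch.distributions import LogNormal, transform_to

# The model's prior on "drift"; its support is the positive reals.
prior = LogNormal(torch.tensor(-1.0), torch.tensor(0.5))

# Unconstrained values an optimizer may move anywhere on the real line.
unconstrained = torch.tensor([-3.0, 0.0, 2.0])

# transform_to maps real numbers into the support (exp for positive reals).
to_support = transform_to(prior.support)
point_estimates = to_support(unconstrained)

# The inverse transform recovers the unconstrained values.
roundtrip = to_support.inv(point_estimates)
```

Because the optimizer works on the unconstrained value, gradient steps can never push the point estimate of drift to zero or below.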

[ ]:
@easy_guide(model)
def guide(self, batch, subsample, full_size):
    # MAP estimate the drift.
    self.map_estimate("drift")

    # Model local states using shared uncertainty + local mean.
    group = self.group(match="state_[0-9]*")  # Selects all local variables.
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full(group.event_shape, 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)
    with self.plate("data", full_size, subsample=subsample):
        # Sample local mean.
        loc = pyro.param("state_loc",
                         lambda: torch.full((full_size,) + group.event_shape, 0.5),
                         event_dim=1)
        # Automatically sample the joint latent, then unpack and replay model sites.
        group.sample("states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))

Note that we used group.event_shape to determine the total flattened, concatenated shape of all matched sites in the group.
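To see what a flattened, concatenated event shape means, here is a plain-torch sketch. It assumes a group flattens each site's event dims, concatenates them in order, and that group.sample() splits a joint draw back into per-site pieces before replaying them at the model sites. The helpers group_event_shape and unpack are hypothetical and are not EasyGuide's implementation.

```python
import torch


def group_event_shape(site_event_shapes):
    """Total flattened event shape of a group of sites (hypothetical helper)."""
    return torch.Size([sum(s.numel() for s in site_event_shapes.values())])


def unpack(joint, site_event_shapes):
    """Split the flat last dim of `joint` into per-site tensors (hypothetical helper)."""
    sizes = [s.numel() for s in site_event_shapes.values()]
    chunks = joint.split(sizes, dim=-1)
    batch_shape = joint.shape[:-1]
    return {
        name: chunk.reshape(batch_shape + shape)
        for (name, shape), chunk in zip(site_event_shapes.items(), chunks)
    }


# In this model every state_t is a scalar per data point: event shape ().
shapes = {"state_{}".format(t): torch.Size([]) for t in range(7)}
joint = torch.randn(20, 7)  # batch of 20 data points, flat event dim of 7
states = unpack(joint, shapes)
```

Since each state_t is scalar per data point, the group's event shape is (7,), which is why the guide can size cov_diag, cov_factor and state_loc directly from group.event_shape.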

[ ]:
train(guide)

Amortized guides

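EasyGuide also makes it easy to write amortized guides, that is, guides that learn a function to predict latent variables from data, rather than learning one parameter per data point. Let's modify the last guide to predict the latent loc as an affine function of the observed data, rather than memorizing each data point's latent variable. This amortized guide is more useful in practice because it can handle new data.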
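The guide's initialization fills the linear layer's weights with 1/num_time_steps and its bias with -0.5. Initially, therefore, every predicted loc entry for a data point equals the mean of that point's observations minus 0.5. The plain-torch sketch below checks this, using hypothetical random binary observations laid out like batch, with shape (num_time_steps, batch_size).

```python
import torch

torch.manual_seed(0)
num_time_steps, batch_size = 7, 20

# Same initialization as the guide's self.nn.
net = torch.nn.Linear(num_time_steps, num_time_steps)
with torch.no_grad():
    net.weight.fill_(1.0 / num_time_steps)
    net.bias.fill_(-0.5)

# Hypothetical binary observations with the same layout as `batch`.
batch = torch.bernoulli(torch.full((num_time_steps, batch_size), 0.3))

# Transpose so each row is one data point: loc has shape (batch_size, num_time_steps).
loc = net(batch.t())
```

Because loc is computed from the observations rather than looked up by index, the same network also produces a loc for data points it has never seen, for any number of them.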

[ ]:
@easy_guide(model)
def guide(self, batch, subsample, full_size):
    num_time_steps, batch_size = batch.shape
    self.map_estimate("drift")

    group = self.group(match="state_[0-9]*")
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full(group.event_shape, 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)

    # Predict latent propensity as an affine function of observed data.
    if not hasattr(self, "nn"):
        self.nn = torch.nn.Linear(group.event_shape.numel(), group.event_shape.numel())
        self.nn.weight.data.fill_(1.0 / num_time_steps)
        self.nn.bias.data.fill_(-0.5)
    pyro.module("state_nn", self.nn)
    with self.plate("data", full_size, subsample=subsample):
        loc = self.nn(batch.t())
        group.sample("states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))
[ ]:
train(guide)