Writing guides using EasyGuide
This tutorial describes the `pyro.contrib.easyguide` module. It assumes the reader is already familiar with SVI and tensor shapes.
Summary
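- For simple black-box guides, try using components in `pyro.infer.autoguide`.
- For more complex guides, try using components in `pyro.contrib.easyguide`.
- Decorate with `@easy_guide(model)`.
- Select multiple model sites using `group = self.group(match="my_regex")`.
- Guide a group of sites by a single distribution using `group.sample(...)`.
- Inspect the concatenated group shape using `group.batch_shape`, `group.event_shape`, etc.
- Use `self.plate(...)` instead of `pyro.plate(...)`.
- To be compatible with subsampling, pass the `event_dim` argument to `pyro.param(...)`.
- To MAP estimate the model site "foo", use `foo = self.map_estimate("foo")`.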
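EasyGuide selects sites by matching their names against a regular expression. The sketch below mimics that selection over a hypothetical list of latent site names from this tutorial's model. It assumes the pattern is applied with Python's `re.match`, which anchors at the start of the name but not at the end. It illustrates regex behavior only and is not EasyGuide's code.

```python
import re

# Hypothetical latent (non-observed) site names, in the order the model samples them.
site_names = ["drift"] + ["state_{}".format(t) for t in range(7)]

def select_sites(names, pattern):
    """Return the names matching `pattern` at the start of the string (re.match semantics)."""
    return [name for name in names if re.match(pattern, name)]

selected = select_sites(site_names, "state_[0-9]*")
print(selected)

# re.match does not anchor the end, so "state_[0-9]*" would also select a site
# such as "state_cov". Appending "$" (or using re.fullmatch) gives an exact match.
print(select_sites(site_names + ["state_cov"], "state_[0-9]*$"))
```

If your model has similarly prefixed sites that should stay out of the group, an end-anchored pattern is the safer choice.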
```python
import os
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.contrib.easyguide import easy_guide
from pyro.optim import Adam
from torch.distributions import constraints

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.1')
```
Modeling time series data
Consider a time-series model with a slowly varying continuous latent state and Bernoulli observations with a logistic link function.
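Here is a dependency-free sketch of the same generative story for a single series. It is not the tutorial's Pyro model. The state performs a Gaussian random walk whose step size is `drift`, and each observation is a Bernoulli draw whose success probability is the logistic sigmoid of the current state. The fixed `drift` value and the helper names are assumptions made for illustration.

```python
import math
import random

def sigmoid(x):
    """Logistic link: maps a real-valued state to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def simulate_series(num_time_steps, drift, rng):
    """Simulate one series of latent states and binary observations."""
    z = 0.0
    states, observations = [], []
    for _ in range(num_time_steps):
        z = rng.gauss(z, drift)          # z_t ~ Normal(z_{t-1}, drift)
        p = sigmoid(z)                   # P(obs_t = 1) = sigmoid(z_t)
        observations.append(1 if rng.random() < p else 0)
        states.append(z)
    return states, observations

rng = random.Random(0)
# The Pyro model draws drift ~ LogNormal(-1, 0.5); here it is fixed at that prior's median, exp(-1).
states, observations = simulate_series(7, math.exp(-1), rng)
print(observations)
```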
```python
def model(batch, subsample, full_size):
    batch = list(batch)
    num_time_steps = len(batch)
    drift = pyro.sample("drift", dist.LogNormal(-1, 0.5))
    with pyro.plate("data", full_size, subsample=subsample):
        z = 0.
        for t in range(num_time_steps):
            z = pyro.sample("state_{}".format(t),
                            dist.Normal(z, drift))
            batch[t] = pyro.sample("obs_{}".format(t),
                                   dist.Bernoulli(logits=z),
                                   obs=batch[t])
    return torch.stack(batch)
```
We generate some data directly from the model.
```python
full_size = 100
num_time_steps = 7
pyro.set_rng_seed(123456789)
data = model([None] * num_time_steps, torch.arange(full_size), full_size)
assert data.shape == (num_time_steps, full_size)
```
Writing a guide without EasyGuide
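Consider one possible guide for this model. We point-estimate the `drift` parameter using a `Delta` distribution. We then model the local time series with a `LowRankMultivariateNormal` distribution that has shared uncertainty but a local mean. The single global sample site can be modeled with one `param` and one `sample` statement. Next we declare a global pair of uncertainty parameters, `cov_diag` and `cov_factor`. Then we declare the local `loc` parameter using `pyro.param(..., event_dim=...)` and draw an auxiliary sample site from the low-rank distribution. Finally we unpack that auxiliary site into one `Delta` site per time step. This pattern of unpacking an auxiliary site into `Delta`s is quite common.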
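The `LowRankMultivariateNormal` distribution parameterizes its covariance as `cov_factor @ cov_factor.T + diag(cov_diag)`. The sketch below uses only `torch.distributions` with arbitrary example values. It checks that identity and shows that a batch of per-series means shares one covariance matrix, which is the "shared uncertainty, local mean" structure described above.

```python
import torch
from torch.distributions import LowRankMultivariateNormal

torch.manual_seed(0)
num_time_steps, rank, batch_size = 7, 3, 20

# Shared uncertainty parameters (arbitrary example values).
cov_factor = torch.randn(num_time_steps, rank) * 0.1
cov_diag = torch.full((num_time_steps,), 0.01)
# One mean vector per series: shape (batch_size, num_time_steps).
loc = torch.full((batch_size, num_time_steps), 0.5)

d = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
print(d.batch_shape, d.event_shape)  # one 7-dimensional distribution per series

# The covariance is low-rank plus diagonal, and it is identical across the batch.
expected = cov_factor @ cov_factor.T + torch.diag(cov_diag)
assert torch.allclose(d.covariance_matrix[0], expected, atol=1e-6)

samples = d.sample()  # shape (batch_size, num_time_steps)
```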
```python
rank = 3

def guide(batch, subsample, full_size):
    num_time_steps, batch_size = batch.shape

    # MAP estimate the drift.
    drift_loc = pyro.param("drift_loc", lambda: torch.tensor(0.1),
                           constraint=constraints.positive)
    pyro.sample("drift", dist.Delta(drift_loc))

    # Model local states using shared uncertainty + local mean.
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full((num_time_steps,), 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(num_time_steps, rank) * 0.01)
    with pyro.plate("data", full_size, subsample=subsample):
        # Sample local mean.
        loc = pyro.param("state_loc",
                         lambda: torch.full((full_size, num_time_steps), 0.5),
                         event_dim=1)
        states = pyro.sample("states",
                             dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag),
                             infer={"is_auxiliary": True})
        # Unpack the joint states into one sample site per time step.
        for t in range(num_time_steps):
            pyro.sample("state_{}".format(t), dist.Delta(states[:, t]))
```
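`state_loc` is declared inside the subsampled `data` plate with `event_dim=1`. Pyro therefore keeps the full `(full_size, num_time_steps)` tensor in the param store but gives the guide only the rows for the current minibatch. The sketch below reproduces that shape behavior with plain tensor indexing. It is an assumption-level illustration of the result, not Pyro's internal mechanism.

```python
import torch

full_size, num_time_steps = 100, 7
# Full parameter: one row of means per series (what the param store holds).
full_loc = torch.arange(full_size * num_time_steps, dtype=torch.float).reshape(full_size, num_time_steps)

# A minibatch of series indices, like the `subsample` argument in this tutorial.
subsample = torch.arange(40, 60)

# With event_dim=1 the rightmost dimension is an event dimension, so the plate
# dimension is the one just left of it (dim -2), and rows are selected there.
batch_loc = full_loc.index_select(-2, subsample)
print(batch_loc.shape)  # torch.Size([20, 7])
```

Without `event_dim`, the subsampled plate would have no way to know which dimension of `state_loc` indexes data points.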
Let's train using SVI and Trace_ELBO, manually splitting the data into minibatches.
```python
def train(guide, num_epochs=1 if smoke_test else 101, batch_size=20):
    full_size = data.size(-1)
    pyro.get_param_store().clear()
    pyro.set_rng_seed(123456789)
    svi = SVI(model, guide, Adam({"lr": 0.02}), Trace_ELBO())
    for epoch in range(num_epochs):
        pos = 0
        losses = []
        while pos < full_size:
            subsample = torch.arange(pos, pos + batch_size)
            batch = data[:, pos:pos + batch_size]
            pos += batch_size
            losses.append(svi.step(batch, subsample, full_size=full_size))
        epoch_loss = sum(losses) / len(losses)
        if epoch % 10 == 0:
            print("epoch {} loss = {}".format(epoch, epoch_loss / data.numel()))
```
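The loop above assumes that `full_size` is a multiple of `batch_size` (100 and 20 here). Otherwise the last `subsample = torch.arange(pos, pos + batch_size)` would run past `full_size`, while the data slice would silently be shorter, and the two would disagree in length. A minimal pure-Python sketch of index ranges that clip the last batch follows. The helper name `minibatch_ranges` is hypothetical.

```python
def minibatch_ranges(full_size, batch_size):
    """Yield (start, stop) index pairs covering range(full_size), clipping the last batch."""
    for start in range(0, full_size, batch_size):
        yield start, min(start + batch_size, full_size)

print(list(minibatch_ranges(100, 20)))
print(list(minibatch_ranges(105, 20)))  # last pair is clipped to (100, 105)
# In the training loop, each pair would become
#   subsample = torch.arange(start, stop); batch = data[:, start:stop]
```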
```python
train(guide)
```
Using EasyGuide
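Now let's simplify using the `@easy_guide` decorator. The modifications are:

1. Decorate with `@easy_guide` and add a `self` argument.
2. Replace the `Delta` guide with a simple `map_estimate()`.
3. Select a `group` of model sites and read their concatenated `event_shape`.
4. Replace the auxiliary site and the `Delta` slices with a single `group.sample()`.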
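Step 4 replaces the manual `Delta` loop. Conceptually, `group.sample()` draws one joint value whose last dimension concatenates every site's flattened event, then splits it back into per-site values. The torch sketch below shows only that split-and-reshape step for a hypothetical group of seven scalar sites. It is a simplified illustration, not EasyGuide's actual implementation.

```python
import torch

batch_size, num_time_steps = 20, 7
# Hypothetical per-site event shapes: each state_t is a scalar inside the plate.
site_event_shapes = {"state_{}".format(t): torch.Size([]) for t in range(num_time_steps)}
sizes = [shape.numel() for shape in site_event_shapes.values()]  # flattened sizes (all 1)

# A joint sample as the group would draw it: batch dims + one concatenated event dim.
joint = torch.randn(batch_size, sum(sizes))

# Split the last dimension into per-site chunks and restore each site's event shape.
unpacked = {
    name: chunk.reshape(joint.shape[:-1] + shape)
    for (name, shape), chunk in zip(site_event_shapes.items(), joint.split(sizes, dim=-1))
}
print({name: tuple(value.shape) for name, value in unpacked.items()})
```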
```python
@easy_guide(model)
def guide(self, batch, subsample, full_size):
    # MAP estimate the drift.
    self.map_estimate("drift")

    # Model local states using shared uncertainty + local mean.
    group = self.group(match="state_[0-9]*")  # Selects all local variables.
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full(group.event_shape, 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)
    with self.plate("data", full_size, subsample=subsample):
        # Sample local mean.
        loc = pyro.param("state_loc",
                         lambda: torch.full((full_size,) + group.event_shape, 0.5),
                         event_dim=1)
        # Automatically sample the joint latent, then unpack and replay model sites.
        group.sample("states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))
```
Note that we use `group.event_shape` to determine the total flattened, concatenated shape of all sites matched by the group.
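In this tutorial each `state_t` is a scalar, so the group's `event_shape` is simply `(num_time_steps,)`. The pure-Python sketch below shows the shape arithmetic for the general case. It assumes each site contributes the number of elements in its event, and the event shapes are hypothetical.

```python
import math

def concatenated_event_shape(event_shapes):
    """Total flattened, concatenated event shape of a group of sites."""
    return (sum(math.prod(shape) for shape in event_shapes),)

print(concatenated_event_shape([()] * 7))            # (7,): seven scalar sites
print(concatenated_event_shape([(), (3,), (2, 2)]))  # (8,): 1 + 3 + 4
```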
```python
train(guide)
```
Amortized guides
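EasyGuide also makes it easy to write amortized guides. These are guides that learn a function predicting latent variables from data, rather than learning a separate parameter for each data point. Let's modify the previous guide to predict the latent `loc` as an affine function of the observed data, instead of memorizing a latent value for each data point. An amortized guide like this is more useful in practice because it can handle new data.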
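The guide below initializes its `torch.nn.Linear` layer with weights `1 / num_time_steps` and bias `-0.5`. Before training, therefore, every output equals the fraction of ones in that series minus 0.5. The torch sketch below checks that initial behavior on made-up binary data. It also shows that the same layer applies unchanged to series it has never seen, which is what makes the guide amortized.

```python
import torch

torch.manual_seed(0)
num_time_steps, batch_size = 7, 20

# Same initialization as the amortized guide below.
predictor = torch.nn.Linear(num_time_steps, num_time_steps)
predictor.weight.data.fill_(1.0 / num_time_steps)
predictor.bias.data.fill_(-0.5)

# Made-up binary observations in the tutorial's layout: (num_time_steps, batch_size).
batch = torch.bernoulli(torch.full((num_time_steps, batch_size), 0.3))
loc = predictor(batch.t())  # shape (batch_size, num_time_steps): one mean vector per series

# Every output coordinate starts as (mean of that series) - 0.5.
expected = batch.t().mean(dim=-1, keepdim=True) - 0.5
assert torch.allclose(loc, expected.expand_as(loc), atol=1e-6)

# New, unseen series need no new parameters: the same layer maps them to means.
new_series = torch.bernoulli(torch.full((5, num_time_steps), 0.7))
print(predictor(new_series).shape)  # torch.Size([5, 7])
```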
```python
@easy_guide(model)
def guide(self, batch, subsample, full_size):
    num_time_steps, batch_size = batch.shape
    self.map_estimate("drift")

    group = self.group(match="state_[0-9]*")
    cov_diag = pyro.param("state_cov_diag",
                          lambda: torch.full(group.event_shape, 0.01),
                          constraint=constraints.positive)
    cov_factor = pyro.param("state_cov_factor",
                            lambda: torch.randn(group.event_shape + (rank,)) * 0.01)

    # Predict latent propensity as an affine function of observed data.
    if not hasattr(self, "nn"):
        self.nn = torch.nn.Linear(group.event_shape.numel(), group.event_shape.numel())
        self.nn.weight.data.fill_(1.0 / num_time_steps)
        self.nn.bias.data.fill_(-0.5)
    pyro.module("state_nn", self.nn)
    with self.plate("data", full_size, subsample=subsample):
        loc = self.nn(batch.t())
        group.sample("states", dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag))
```
```python
train(guide)
```