Pyro 中的模块¶
本教程介绍 PyroModule,它是 Pyro 对 PyTorch 的 nn.Module 类的贝叶斯扩展。在开始之前,您应该了解 Pyro 模型和推断 的基础知识,理解两个原语 pyro.sample() 和 pyro.param(),并了解 Pyro 效果处理器 的基础知识(例如,通过浏览 minipyro.py)。
摘要:¶
PyroModule 类似于 nn.Module,但允许 Pyro 效果用于采样和约束。
PyroModule
是nn.Module
的一个 mixin 子类,它重写了属性访问(例如.__getattr__()
)。有三种不同的方式可以创建
PyroModule
创建新的子类:
class MyModule(PyroModule): ...
,将现有类 Pyro 化:
MyModule = PyroModule[OtherModule]
,或者将现有
nn.Module
实例就地 Pyro 化:to_pyro_module_(my_module)
。
PyroModule
的通常nn.Parameter
属性变为 Pyro 参数。PyroModule
的参数与 Pyro 的全局参数存储同步。您可以通过创建 PyroParam 对象来添加受约束的参数。
您可以通过创建 PyroSample 对象来添加随机属性。
参数和随机属性会自动命名(无需字符串)。
PyroSample
属性在最外层PyroModule
的每次.__call__()
调用时采样一次。要在
.__call__()
之外的方法上启用 Pyro 效果,请使用 @pyro_method 装饰它们。PyroModule
模型可以包含nn.Module
属性。一个
nn.Module
模型最多可以包含一个PyroModule
属性(参见命名部分)。一个
nn.Module
可以同时包含PyroModule
模型和PyroModule
guide(例如 Predictive)。
目录¶
[1]:
import os
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.nn.module import to_pyro_module_
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.1')
PyroModule
如何工作¶
PyroModule
旨在将 Pyro 的原语和效果处理器与 PyTorch 的 nn.Module 惯用法结合起来,从而实现对现有 nn.Module
的贝叶斯处理,并支持通过 jit.trace_module 进行模型服务。在开始使用 PyroModule
之前,了解它们的工作原理将有所帮助,以便您可以避免陷阱。
PyroModule
是 nn.Module
的子类。PyroModule
通过在模块属性访问时插入效果处理逻辑来启用 Pyro 效果,重写了 .__getattr__()
、.__setattr__()
和 .__delattr__()
方法。此外,由于某些效果(如采样)在每次模型调用时仅应用一次,PyroModule
重写了 .__call__()
方法以确保在每次 .__call__()
调用时最多生成一次样本(注意 nn.Module
子类通常实现一个由 .__call__()
调用的 .forward()
方法)。
如何创建 PyroModule
¶
有三种方式可以创建 PyroModule
。让我们从一个不是 PyroModule
的 nn.Module
开始
[2]:
class Linear(nn.Module):
def __init__(self, in_size, out_size):
super().__init__()
self.weight = nn.Parameter(torch.randn(in_size, out_size))
self.bias = nn.Parameter(torch.randn(out_size))
def forward(self, input_):
return self.bias + input_ @ self.weight
linear = Linear(5, 2)
assert isinstance(linear, nn.Module)
assert not isinstance(linear, PyroModule)
example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
创建 PyroModule
的第一种方式是创建 PyroModule
的子类。您可以更新您编写的任何 nn.Module
使其成为 PyroModule,例如
- class Linear(nn.Module):
+ class Linear(PyroModule):
def __init__(self, in_size, out_size):
super().__init__()
self.weight = ...
self.bias = ...
...
或者,如果您想使用像上面的 Linear
这样的第三方代码,您可以使用 PyroModule
作为 mixin 类来创建其子类
[3]:
class PyroLinear(Linear, PyroModule):
pass
linear = PyroLinear(5, 2)
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)
example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
创建 PyroModule
的第二种方式是使用括号语法 PyroModule[-]
来自动表示如上所示的简单 mixin 类。
- linear = Linear(5, 2)
+ linear = PyroModule[Linear](5, 2)
在我们的例子中,我们可以这样写
[4]:
linear = PyroModule[Linear](5, 2)
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)
example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
手动子类化和使用 PyroModule[-]
之间的一个区别是,PyroModule[-]
还确保所有 nn.Module
超类也成为 PyroModule
,这对于库代码中的类层次结构很重要。例如,由于 nn.GRU
是 nn.RNN
的子类,因此 PyroModule[nn.GRU]
也将是 PyroModule[nn.RNN]
的子类。
创建 PyroModule
的第三种方式是使用 to_pyro_module_() 就地更改现有 nn.Module
实例的类型。如果您使用第三方模块工厂助手或更新现有脚本,这非常有用,例如
[5]:
linear = Linear(5, 2)
assert isinstance(linear, nn.Module)
assert not isinstance(linear, PyroModule)
to_pyro_module_(linear) # this operates in-place
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)
example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
效果如何工作¶
到目前为止,我们已经创建了 PyroModule
,但还没有利用 Pyro 效果。但我们的 PyroModule
的 nn.Parameter
属性已经像 pyro.param 语句一样:它们与 Pyro 的参数存储同步,并可以记录在 traces 中。
[6]:
pyro.clear_param_store()
# This is not traced:
linear = Linear(5, 2)
with poutine.trace() as tr:
linear(example_input)
print(type(linear).__name__)
print(list(tr.trace.nodes.keys()))
print(list(pyro.get_param_store().keys()))
# Now this is traced:
to_pyro_module_(linear)
with poutine.trace() as tr:
linear(example_input)
print(type(linear).__name__)
print(list(tr.trace.nodes.keys()))
print(list(pyro.get_param_store().keys()))
Linear
[]
[]
PyroLinear
['bias', 'weight']
['bias', 'weight']
如何约束参数¶
Pyro 参数允许约束,我们经常希望我们的 nn.Module
参数遵守约束。您可以通过将 nn.Parameter
替换为 PyroParam 属性来约束 PyroModule
的参数。例如,为确保 .bias
属性为正,我们可以将其设置为
[7]:
print("params before:", [name for name, _ in linear.named_parameters()])
linear.bias = PyroParam(torch.randn(2).exp(), constraint=constraints.positive)
print("params after:", [name for name, _ in linear.named_parameters()])
print("bias:", linear.bias)
example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
params before: ['weight', 'bias']
params after: ['weight', 'bias_unconstrained']
bias: tensor([0.9777, 0.8773], grad_fn=<AddBackward0>)
现在 PyTorch 将优化 .bias_unconstrained
参数,每次我们访问 .bias
属性时,它将读取并转换 .bias_unconstrained
参数(类似于 Python 的 @property
)。
如果您事先知道约束,可以将其构建到模块构造函数中,例如
class Linear(PyroModule):
def __init__(self, in_size, out_size):
super().__init__()
self.weight = ...
- self.bias = nn.Parameter(torch.randn(out_size))
+ self.bias = PyroParam(torch.randn(out_size).exp(),
+ constraint=constraints.positive)
...
如何使 PyroModule
具有贝叶斯特性¶
到目前为止,我们的 Linear
模块仍然是确定性的。为了使其随机化并具有贝叶斯特性,我们将用 PyroSample 属性替换 nn.Parameter
和 PyroParam
属性,并指定一个 prior。让我们在权重上设置一个简单的 prior,注意将其形状扩展到 [5,2]
并使用 .to_event() 声明事件维度(如张量形状教程中解释的那样)。
[8]:
print("params before:", [name for name, _ in linear.named_parameters()])
linear.weight = PyroSample(dist.Normal(0, 1).expand([5, 2]).to_event(2))
print("params after:", [name for name, _ in linear.named_parameters()])
print("weight:", linear.weight)
print("weight:", linear.weight)
example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
params before: ['weight', 'bias_unconstrained']
params after: ['bias_unconstrained']
weight: tensor([[-0.8668, -0.0150],
[ 3.4642, 1.9076],
[ 0.4717, 1.0565],
[-1.2032, 1.0821],
[-0.1712, 0.4711]])
weight: tensor([[-1.2577, -0.5242],
[-0.7785, -1.0806],
[ 0.6239, -0.4884],
[-0.2580, -1.2288],
[-0.7540, -1.9375]])
请注意,.weight
参数现在消失了,并且每次我们调用 linear()
时,都会从 prior 中采样一个新的权重。实际上,权重是在 Linear.forward()
访问 .weight
属性时采样的:该属性现在具有从 prior 采样的特殊行为。
我们可以看到跟踪中出现的所有 Pyro 效果
[9]:
with poutine.trace() as tr:
linear(example_input)
for site in tr.trace.nodes.values():
print(site["type"], site["name"], site["value"])
param bias tensor([0.9777, 0.8773], grad_fn=<AddBackward0>)
sample weight tensor([[ 1.8043, 1.5494],
[ 0.0128, 1.4100],
[-0.2155, 0.6375],
[ 1.1202, 1.9672],
[-0.1576, -0.6957]])
到目前为止,我们已经修改了一个第三方模块,使其具有贝叶斯特性
linear = Linear(...)
to_pyro_module_(linear)
linear.bias = PyroParam(...)
linear.weight = PyroSample(...)
如果您从头开始创建模型,可以 instead 定义一个新的类
[10]:
class BayesianLinear(PyroModule):
def __init__(self, in_size, out_size):
super().__init__()
self.bias = PyroSample(
prior=dist.LogNormal(0, 1).expand([out_size]).to_event(1))
self.weight = PyroSample(
prior=dist.Normal(0, 1).expand([in_size, out_size]).to_event(2))
def forward(self, input):
return self.bias + input @ self.weight # this line samples bias and weight
请注意,每次 .__call__()
调用时最多采样一次,例如
class BayesianLinear(PyroModule):
...
def forward(self, input):
weight1 = self.weight # Draws a sample.
weight2 = self.weight # Reads previous sample.
assert weight2 is weight1 # All accesses should agree.
...
⚠ 注意:在 plate 内部访问属性¶
由于 PyroSample
和 PyroParam
属性受 Pyro 效果影响,我们需要注意参数访问的位置。例如,pyro.plate 上下文可以改变采样和参数 site 的形状。考虑一个具有一个潜变量和批量观测语句的模型。我们看到这两个模型之间唯一的区别在于 .loc
属性的访问位置。
[11]:
class NormalModel(PyroModule):
def __init__(self):
super().__init__()
self.loc = PyroSample(dist.Normal(0, 1))
class GlobalModel(NormalModel):
def forward(self, data):
# If .loc is accessed (for the first time) outside the plate,
# then it will have empty shape ().
loc = self.loc
assert loc.shape == ()
with pyro.plate("data", len(data)):
pyro.sample("obs", dist.Normal(loc, 1), obs=data)
class LocalModel(NormalModel):
def forward(self, data):
with pyro.plate("data", len(data)):
# If .loc is accessed (for the first time) inside the plate,
# then it will be expanded by the plate to shape (plate.size,).
loc = self.loc
assert loc.shape == (len(data),)
pyro.sample("obs", dist.Normal(loc, 1), obs=data)
data = torch.randn(10)
LocalModel()(data)
GlobalModel()(data)
如何创建复杂的嵌套 PyroModule
¶
为了对上面的 BayesianLinear
模块进行推断,我们需要将其包装在一个带有似然的概率模型中;这个包装器也将是一个 PyroModule
。
[12]:
class Model(PyroModule):
def __init__(self, in_size, out_size):
super().__init__()
self.linear = BayesianLinear(in_size, out_size) # this is a PyroModule
self.obs_scale = PyroSample(dist.LogNormal(0, 1))
def forward(self, input, output=None):
obs_loc = self.linear(input) # this samples linear.bias and linear.weight
obs_scale = self.obs_scale # this samples self.obs_scale
with pyro.plate("instances", len(input)):
return pyro.sample("obs", dist.Normal(obs_loc, obs_scale).to_event(1),
obs=output)
通常 nn.Module
可以使用简单的 PyTorch 优化器进行训练,而 Pyro 模型需要概率推断,例如使用 SVI 和 AutoNormal guide。详情请参阅贝叶斯回归教程。
[13]:
%%time
pyro.clear_param_store()
pyro.set_rng_seed(1)
model = Model(5, 2)
x = torch.randn(100, 5)
y = model(x)
guide = AutoNormal(model)
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
for step in range(2 if smoke_test else 501):
loss = svi.step(x, y) / y.numel()
if step % 100 == 0:
print("step {} loss = {:0.4g}".format(step, loss))
step 0 loss = 7.186
step 100 loss = 2.185
step 200 loss = 1.87
step 300 loss = 1.739
step 400 loss = 1.691
step 500 loss = 1.673
CPU times: user 2.35 s, sys: 24.8 ms, total: 2.38 s
Wall time: 2.39 s
PyroSample
语句也可能依赖于其他采样语句或参数。在这种情况下,prior
可以是依赖于 self
的可调用对象,而不是一个固定的分布。例如,考虑分层模型
[14]:
class Model(PyroModule):
def __init__(self):
super().__init__()
self.dof = PyroSample(dist.Gamma(3, 1))
self.loc = PyroSample(dist.Normal(0, 1))
self.scale = PyroSample(lambda self: dist.InverseGamma(self.dof, 1))
self.x = PyroSample(lambda self: dist.Normal(self.loc, self.scale))
def forward(self):
return self.x
Model()()
[14]:
tensor(0.5387)
命名如何工作¶
在上面的代码中,我们看到了嵌入在另一个 Model
中的 BayesianLinear
模型。两者都是 PyroModule
。简单的 pyro.sample 语句需要名称字符串,而 PyroModule
属性会自动处理命名。让我们看看上面的 model
和 guide
(因为 AutoNormal
也是一个 PyroModule
)是如何工作的。
让我们跟踪模型和 guide 的执行。
[15]:
with poutine.trace() as tr:
model(x)
for site in tr.trace.nodes.values():
print(site["type"], site["name"], site["value"].shape)
sample linear.bias torch.Size([2])
sample linear.weight torch.Size([5, 2])
sample obs_scale torch.Size([])
sample instances torch.Size([100])
sample obs torch.Size([100, 2])
请注意,model.linear.bias
对应于 linear.bias
名称,对于 model.linear.weight
和 model.obs_scale
属性也是如此。“instances” site 对应于 plate,“obs” site 对应于似然。接下来查看 guide
[16]:
with poutine.trace() as tr:
guide(x)
for site in tr.trace.nodes.values():
print(site["type"], site["name"], site["value"].shape)
param AutoNormal.locs.linear.bias torch.Size([2])
param AutoNormal.scales.linear.bias torch.Size([2])
sample linear.bias_unconstrained torch.Size([2])
sample linear.bias torch.Size([2])
param AutoNormal.locs.linear.weight torch.Size([5, 2])
param AutoNormal.scales.linear.weight torch.Size([5, 2])
sample linear.weight_unconstrained torch.Size([5, 2])
sample linear.weight torch.Size([5, 2])
param AutoNormal.locs.obs_scale torch.Size([])
param AutoNormal.scales.obs_scale torch.Size([])
sample obs_scale_unconstrained torch.Size([])
sample obs_scale torch.Size([])
我们看到 guide 学习了三个随机变量的后验:linear.bias
、linear.weight
和 obs_scale
。对于每个变量,guide 学习了一对 (loc,scale)
参数,这些参数内部存储在嵌套的 PyroModule
中
class AutoNormal(...):
def __init__(self, ...):
self.locs = PyroModule()
self.scales = PyroModule()
...
最后,AutoNormal
包含一个对应每个无约束潜 site 的 pyro.sample
语句,其后是一个 pyro.deterministic 语句,用于将无约束样本映射到受约束的后验样本。
⚠ 注意:避免重复名称¶
PyroModule
会自动为其属性命名,即使是嵌套在其他 PyroModule
中的深层属性。然而,在混合使用普通 nn.Module
和 PyroModule
时必须小心,因为 nn.Module
不支持自动 site 命名。
在单个模型(或 guide)中
如果只有一个 PyroModule
,那么是安全的。
class Model(nn.Module): # not a PyroModule
def __init__(self):
self.x = PyroModule()
- self.y = PyroModule() # Could lead to name conflict.
+ self.y = nn.Module() # Has no Pyro names, so avoids conflict.
如果只有两个 PyroModule
,那么其中一个必须是另一个的属性。
class Model(PyroModule):
def __init__(self):
self.x = PyroModule() # ok
如果您有两个互不为属性的 PyroModule
,那么它们必须通过其他 PyroModule
的属性链接连接起来。这些可以是同级链接
- class Model(nn.Module): # Could lead to name conflict.
+ class Model(PyroModule): # Ensures names are unique.
def __init__(self):
self.x = PyroModule()
self.y = PyroModule()
或祖先链接
class Model(PyroModule):
def __init__(self):
- self.x = nn.Module() # Could lead to name conflict.
+ self.x = PyroModule() # Ensures y is conected to root Model.
self.x.y = PyroModule()
有时您可能希望将 (model,guide)
对存储在单个 nn.Module
中,例如为了从 C++ 提供服务。在这种情况下,将它们作为容器 nn.Module
的属性是安全的,但该容器不应该是一个 PyroModule
。
class Container(nn.Module): # This cannot be a PyroModule.
def __init__(self, model, guide): # These may be PyroModules.
super().__init__()
self.model = model
self.guide = guide
# This is a typical trace-replay pattern seen in model serving.
def forward(self, data):
tr = poutine.trace(self.guide).get_trace(data)
return poutine.replay(model, tr)(data)
[ ]: