Pyro 中的模块

本教程介绍 PyroModule,它是 Pyro 对 PyTorch 的 nn.Module 类的贝叶斯扩展。在开始之前,您应该了解 Pyro 模型和推断 的基础知识,理解两个原语 pyro.sample()pyro.param(),并了解 Pyro 效果处理器 的基础知识(例如,通过浏览 minipyro.py)。

摘要:

  • PyroModule 类似于 nn.Module,但允许 Pyro 效果用于采样和约束。

  • PyroModulenn.Module 的一个 mixin 子类,它重写了属性访问(例如 .__getattr__())。

  • 有三种不同的方式可以创建 PyroModule

    • 创建新的子类:class MyModule(PyroModule): ...

    • 将现有类 Pyro 化:MyModule = PyroModule[OtherModule],或者

    • 将现有 nn.Module 实例就地 Pyro 化:to_pyro_module_(my_module)

  • PyroModule 的通常 nn.Parameter 属性变为 Pyro 参数。

  • PyroModule 的参数与 Pyro 的全局参数存储同步。

  • 您可以通过创建 PyroParam 对象来添加受约束的参数。

  • 您可以通过创建 PyroSample 对象来添加随机属性。

  • 参数和随机属性会自动命名(无需字符串)。

  • PyroSample 属性在最外层 PyroModule 的每次 .__call__() 调用时采样一次。

  • 要在 .__call__() 之外的方法上启用 Pyro 效果,请使用 @pyro_method 装饰它们。

  • PyroModule 模型可以包含 nn.Module 属性。

  • 一个 nn.Module 模型最多可以包含一个 PyroModule 属性(参见命名部分)。

  • 一个 nn.Module 可以同时包含 PyroModule 模型和 PyroModule guide(例如 Predictive)。

目录

[1]:
import os
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from torch.distributions import constraints
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.nn.module import to_pyro_module_
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.1')

PyroModule 如何工作

PyroModule 旨在将 Pyro 的原语和效果处理器与 PyTorch 的 nn.Module 惯用法结合起来,从而实现对现有 nn.Module 的贝叶斯处理,并支持通过 jit.trace_module 进行模型服务。在开始使用 PyroModule 之前,了解它们的工作原理将有所帮助,以便您可以避免陷阱。

PyroModulenn.Module 的子类。PyroModule 通过在模块属性访问时插入效果处理逻辑来启用 Pyro 效果,重写了 .__getattr__().__setattr__().__delattr__() 方法。此外,由于某些效果(如采样)在每次模型调用时仅应用一次,PyroModule 重写了 .__call__() 方法以确保在每次 .__call__() 调用时最多生成一次样本(注意 nn.Module 子类通常实现一个由 .__call__() 调用的 .forward() 方法)。

如何创建 PyroModule

有三种方式可以创建 PyroModule。让我们从一个不是 PyroModulenn.Module 开始

[2]:
class Linear(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(in_size, out_size))
        self.bias = nn.Parameter(torch.randn(out_size))

    def forward(self, input_):
        return self.bias + input_ @ self.weight

linear = Linear(5, 2)
assert isinstance(linear, nn.Module)
assert not isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)

创建 PyroModule 的第一种方式是创建 PyroModule 的子类。您可以更新您编写的任何 nn.Module 使其成为 PyroModule,例如

- class Linear(nn.Module):
+ class Linear(PyroModule):
      def __init__(self, in_size, out_size):
          super().__init__()
          self.weight = ...
          self.bias = ...
      ...

或者,如果您想使用像上面的 Linear 这样的第三方代码,您可以使用 PyroModule 作为 mixin 类来创建其子类

[3]:
class PyroLinear(Linear, PyroModule):
    pass

linear = PyroLinear(5, 2)
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)

创建 PyroModule 的第二种方式是使用括号语法 PyroModule[-] 来自动表示如上所示的简单 mixin 类。

- linear = Linear(5, 2)
+ linear = PyroModule[Linear](5, 2)

在我们的例子中,我们可以这样写

[4]:
linear = PyroModule[Linear](5, 2)
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)

手动子类化和使用 PyroModule[-] 之间的一个区别是,PyroModule[-] 还确保所有 nn.Module 超类也成为 PyroModule,这对于库代码中的类层次结构很重要。例如,由于 nn.GRUnn.RNN 的子类,因此 PyroModule[nn.GRU] 也将是 PyroModule[nn.RNN] 的子类。

创建 PyroModule 的第三种方式是使用 to_pyro_module_() 就地更改现有 nn.Module 实例的类型。如果您使用第三方模块工厂助手或更新现有脚本,这非常有用,例如

[5]:
linear = Linear(5, 2)
assert isinstance(linear, nn.Module)
assert not isinstance(linear, PyroModule)

to_pyro_module_(linear)  # this operates in-place
assert isinstance(linear, nn.Module)
assert isinstance(linear, Linear)
assert isinstance(linear, PyroModule)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)

效果如何工作

到目前为止,我们已经创建了 PyroModule,但还没有利用 Pyro 效果。但我们的 PyroModulenn.Parameter 属性已经像 pyro.param 语句一样:它们与 Pyro 的参数存储同步,并可以记录在 traces 中。

[6]:
pyro.clear_param_store()

# This is not traced:
linear = Linear(5, 2)
with poutine.trace() as tr:
    linear(example_input)
print(type(linear).__name__)
print(list(tr.trace.nodes.keys()))
print(list(pyro.get_param_store().keys()))

# Now this is traced:
to_pyro_module_(linear)
with poutine.trace() as tr:
    linear(example_input)
print(type(linear).__name__)
print(list(tr.trace.nodes.keys()))
print(list(pyro.get_param_store().keys()))
Linear
[]
[]
PyroLinear
['bias', 'weight']
['bias', 'weight']

如何约束参数

Pyro 参数允许约束,我们经常希望我们的 nn.Module 参数遵守约束。您可以通过将 nn.Parameter 替换为 PyroParam 属性来约束 PyroModule 的参数。例如,为确保 .bias 属性为正,我们可以将其设置为

[7]:
print("params before:", [name for name, _ in linear.named_parameters()])

linear.bias = PyroParam(torch.randn(2).exp(), constraint=constraints.positive)
print("params after:", [name for name, _ in linear.named_parameters()])
print("bias:", linear.bias)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
params before: ['weight', 'bias']
params after: ['weight', 'bias_unconstrained']
bias: tensor([0.9777, 0.8773], grad_fn=<AddBackward0>)

现在 PyTorch 将优化 .bias_unconstrained 参数,每次我们访问 .bias 属性时,它将读取并转换 .bias_unconstrained 参数(类似于 Python 的 @property)。

如果您事先知道约束,可以将其构建到模块构造函数中,例如

  class Linear(PyroModule):
      def __init__(self, in_size, out_size):
          super().__init__()
          self.weight = ...
-         self.bias = nn.Parameter(torch.randn(out_size))
+         self.bias = PyroParam(torch.randn(out_size).exp(),
+                               constraint=constraints.positive)
      ...

如何使 PyroModule 具有贝叶斯特性

到目前为止,我们的 Linear 模块仍然是确定性的。为了使其随机化并具有贝叶斯特性,我们将用 PyroSample 属性替换 nn.ParameterPyroParam 属性,并指定一个 prior。让我们在权重上设置一个简单的 prior,注意将其形状扩展到 [5,2] 并使用 .to_event() 声明事件维度(如张量形状教程中解释的那样)。

[8]:
print("params before:", [name for name, _ in linear.named_parameters()])

linear.weight = PyroSample(dist.Normal(0, 1).expand([5, 2]).to_event(2))
print("params after:", [name for name, _ in linear.named_parameters()])
print("weight:", linear.weight)
print("weight:", linear.weight)

example_input = torch.randn(100, 5)
example_output = linear(example_input)
assert example_output.shape == (100, 2)
params before: ['weight', 'bias_unconstrained']
params after: ['bias_unconstrained']
weight: tensor([[-0.8668, -0.0150],
        [ 3.4642,  1.9076],
        [ 0.4717,  1.0565],
        [-1.2032,  1.0821],
        [-0.1712,  0.4711]])
weight: tensor([[-1.2577, -0.5242],
        [-0.7785, -1.0806],
        [ 0.6239, -0.4884],
        [-0.2580, -1.2288],
        [-0.7540, -1.9375]])

请注意,.weight 参数现在消失了,并且每次我们调用 linear() 时,都会从 prior 中采样一个新的权重。实际上,权重是在 Linear.forward() 访问 .weight 属性时采样的:该属性现在具有从 prior 采样的特殊行为。

我们可以看到跟踪中出现的所有 Pyro 效果

[9]:
with poutine.trace() as tr:
    linear(example_input)
for site in tr.trace.nodes.values():
    print(site["type"], site["name"], site["value"])
param bias tensor([0.9777, 0.8773], grad_fn=<AddBackward0>)
sample weight tensor([[ 1.8043,  1.5494],
        [ 0.0128,  1.4100],
        [-0.2155,  0.6375],
        [ 1.1202,  1.9672],
        [-0.1576, -0.6957]])

到目前为止,我们已经修改了一个第三方模块,使其具有贝叶斯特性

linear = Linear(...)
to_pyro_module_(linear)
linear.bias = PyroParam(...)
linear.weight = PyroSample(...)

如果您从头开始创建模型,可以 instead 定义一个新的类

[10]:
class BayesianLinear(PyroModule):
    def __init__(self, in_size, out_size):
       super().__init__()
       self.bias = PyroSample(
           prior=dist.LogNormal(0, 1).expand([out_size]).to_event(1))
       self.weight = PyroSample(
           prior=dist.Normal(0, 1).expand([in_size, out_size]).to_event(2))

    def forward(self, input):
        return self.bias + input @ self.weight  # this line samples bias and weight

请注意,每次 .__call__() 调用时最多采样一次,例如

class BayesianLinear(PyroModule):
    ...
    def forward(self, input):
        weight1 = self.weight      # Draws a sample.
        weight2 = self.weight      # Reads previous sample.
        assert weight2 is weight1  # All accesses should agree.
        ...

⚠ 注意:在 plate 内部访问属性

由于 PyroSamplePyroParam 属性受 Pyro 效果影响,我们需要注意参数访问的位置。例如,pyro.plate 上下文可以改变采样和参数 site 的形状。考虑一个具有一个潜变量和批量观测语句的模型。我们看到这两个模型之间唯一的区别在于 .loc 属性的访问位置。

[11]:
class NormalModel(PyroModule):
    def __init__(self):
        super().__init__()
        self.loc = PyroSample(dist.Normal(0, 1))

class GlobalModel(NormalModel):
    def forward(self, data):
        # If .loc is accessed (for the first time) outside the plate,
        # then it will have empty shape ().
        loc = self.loc
        assert loc.shape == ()
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(loc, 1), obs=data)

class LocalModel(NormalModel):
    def forward(self, data):
        with pyro.plate("data", len(data)):
            # If .loc is accessed (for the first time) inside the plate,
            # then it will be expanded by the plate to shape (plate.size,).
            loc = self.loc
            assert loc.shape == (len(data),)
            pyro.sample("obs", dist.Normal(loc, 1), obs=data)

data = torch.randn(10)
LocalModel()(data)
GlobalModel()(data)

如何创建复杂的嵌套 PyroModule

为了对上面的 BayesianLinear 模块进行推断,我们需要将其包装在一个带有似然的概率模型中;这个包装器也将是一个 PyroModule

[12]:
class Model(PyroModule):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.linear = BayesianLinear(in_size, out_size)  # this is a PyroModule
        self.obs_scale = PyroSample(dist.LogNormal(0, 1))

    def forward(self, input, output=None):
        obs_loc = self.linear(input)  # this samples linear.bias and linear.weight
        obs_scale = self.obs_scale    # this samples self.obs_scale
        with pyro.plate("instances", len(input)):
            return pyro.sample("obs", dist.Normal(obs_loc, obs_scale).to_event(1),
                               obs=output)

通常 nn.Module 可以使用简单的 PyTorch 优化器进行训练,而 Pyro 模型需要概率推断,例如使用 SVIAutoNormal guide。详情请参阅贝叶斯回归教程

[13]:
%%time
pyro.clear_param_store()
pyro.set_rng_seed(1)

model = Model(5, 2)
x = torch.randn(100, 5)
y = model(x)

guide = AutoNormal(model)
svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
for step in range(2 if smoke_test else 501):
    loss = svi.step(x, y) / y.numel()
    if step % 100 == 0:
        print("step {} loss = {:0.4g}".format(step, loss))
step 0 loss = 7.186
step 100 loss = 2.185
step 200 loss = 1.87
step 300 loss = 1.739
step 400 loss = 1.691
step 500 loss = 1.673
CPU times: user 2.35 s, sys: 24.8 ms, total: 2.38 s
Wall time: 2.39 s

PyroSample 语句也可能依赖于其他采样语句或参数。在这种情况下,prior 可以是依赖于 self 的可调用对象,而不是一个固定的分布。例如,考虑分层模型

[14]:
class Model(PyroModule):
    def __init__(self):
        super().__init__()
        self.dof = PyroSample(dist.Gamma(3, 1))
        self.loc = PyroSample(dist.Normal(0, 1))
        self.scale = PyroSample(lambda self: dist.InverseGamma(self.dof, 1))
        self.x = PyroSample(lambda self: dist.Normal(self.loc, self.scale))

    def forward(self):
        return self.x

Model()()
[14]:
tensor(0.5387)

命名如何工作

在上面的代码中,我们看到了嵌入在另一个 Model 中的 BayesianLinear 模型。两者都是 PyroModule。简单的 pyro.sample 语句需要名称字符串,而 PyroModule 属性会自动处理命名。让我们看看上面的 modelguide(因为 AutoNormal 也是一个 PyroModule)是如何工作的。

让我们跟踪模型和 guide 的执行。

[15]:
with poutine.trace() as tr:
    model(x)
for site in tr.trace.nodes.values():
    print(site["type"], site["name"], site["value"].shape)
sample linear.bias torch.Size([2])
sample linear.weight torch.Size([5, 2])
sample obs_scale torch.Size([])
sample instances torch.Size([100])
sample obs torch.Size([100, 2])

请注意,model.linear.bias 对应于 linear.bias 名称,对于 model.linear.weightmodel.obs_scale 属性也是如此。“instances” site 对应于 plate,“obs” site 对应于似然。接下来查看 guide

[16]:
with poutine.trace() as tr:
    guide(x)
for site in tr.trace.nodes.values():
    print(site["type"], site["name"], site["value"].shape)
param AutoNormal.locs.linear.bias torch.Size([2])
param AutoNormal.scales.linear.bias torch.Size([2])
sample linear.bias_unconstrained torch.Size([2])
sample linear.bias torch.Size([2])
param AutoNormal.locs.linear.weight torch.Size([5, 2])
param AutoNormal.scales.linear.weight torch.Size([5, 2])
sample linear.weight_unconstrained torch.Size([5, 2])
sample linear.weight torch.Size([5, 2])
param AutoNormal.locs.obs_scale torch.Size([])
param AutoNormal.scales.obs_scale torch.Size([])
sample obs_scale_unconstrained torch.Size([])
sample obs_scale torch.Size([])

我们看到 guide 学习了三个随机变量的后验:linear.biaslinear.weightobs_scale。对于每个变量,guide 学习了一对 (loc,scale) 参数,这些参数内部存储在嵌套的 PyroModule

class AutoNormal(...):
    def __init__(self, ...):
        self.locs = PyroModule()
        self.scales = PyroModule()
        ...

最后,AutoNormal 包含一个对应每个无约束潜 site 的 pyro.sample 语句,其后是一个 pyro.deterministic 语句,用于将无约束样本映射到受约束的后验样本。

⚠ 注意:避免重复名称

PyroModule 会自动为其属性命名,即使是嵌套在其他 PyroModule 中的深层属性。然而,在混合使用普通 nn.ModulePyroModule 时必须小心,因为 nn.Module 不支持自动 site 命名。

在单个模型(或 guide)中

如果只有一个 PyroModule,那么是安全的。

  class Model(nn.Module):        # not a PyroModule
      def __init__(self):
          self.x = PyroModule()
-         self.y = PyroModule()  # Could lead to name conflict.
+         self.y = nn.Module()  # Has no Pyro names, so avoids conflict.

如果只有两个 PyroModule,那么其中一个必须是另一个的属性。

class Model(PyroModule):
    def __init__(self):
       self.x = PyroModule()  # ok

如果您有两个互不为属性的 PyroModule,那么它们必须通过其他 PyroModule 的属性链接连接起来。这些可以是同级链接

- class Model(nn.Module):     # Could lead to name conflict.
+ class Model(PyroModule):    # Ensures names are unique.
      def __init__(self):
          self.x = PyroModule()
          self.y = PyroModule()

或祖先链接

  class Model(PyroModule):
      def __init__(self):
-         self.x = nn.Module()    # Could lead to name conflict.
+         self.x = PyroModule()   # Ensures y is conected to root Model.
          self.x.y = PyroModule()

有时您可能希望将 (model,guide) 对存储在单个 nn.Module 中,例如为了从 C++ 提供服务。在这种情况下,将它们作为容器 nn.Module 的属性是安全的,但该容器不应该是一个 PyroModule

class Container(nn.Module):            # This cannot be a PyroModule.
    def __init__(self, model, guide):  # These may be PyroModules.
        super().__init__()
        self.model = model
        self.guide = guide
    # This is a typical trace-replay pattern seen in model serving.
    def forward(self, data):
        tr = poutine.trace(self.guide).get_trace(data)
        return poutine.replay(model, tr)(data)
[ ]: