将 PyTorch JIT 编译器与 Pyro 结合使用

本教程展示了如何在 Pyro 模型中使用 PyTorch JIT 编译器

摘要:

  • 您可以在 Pyro 模型中使用编译后的函数。

  • 您不能在编译后的函数中使用 Pyro 原语。

  • 如果您的模型结构是静态的,您可以使用 Jit* 版本的 ELBO 算法,例如:

    - Trace_ELBO()
    + JitTrace_ELBO()
    
  • HMCNUTS 类接受 jit_compile=True 关键字参数。

  • 模型应将所有张量作为 *args 输入,将所有非张量作为 **kwargs 输入。

  • **kwargs 的每个不同值都会触发一次单独的编译。

  • 使用 **kwargs 指定结构中的所有变化(例如时间序列长度)。

  • 要在安全的代码块中忽略 JIT 警告,请使用 with pyro.util.ignore_jit_warnings():

  • 要在 HMCNUTS 中忽略所有 JIT 警告,请传递 ignore_jit_warnings=True

目录

[1]:
import os
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro import poutine
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, JitTrace_ELBO, TraceEnum_ELBO, JitTraceEnum_ELBO, SVI
from pyro.infer.mcmc import MCMC, NUTS
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.1')

简介

PyTorch 1.0 包含一个 JIT 编译器 以加速模型。您可以将编译视为一种“静态模式”,而 PyTorch 通常在“即时模式”下运行。

Pyro 通过两种方式支持 JIT 编译器。首先,您可以在 Pyro 模型中使用编译后的函数(但这些函数不能包含 Pyro 原语)。其次,您可以使用 Pyro 的 JIT 推断算法来编译整个推断步骤;在静态模型中,这可以减少 Pyro 模型的 Python 开销并加速推断。

本教程的其余部分重点介绍 Pyro 的 JIT 推断算法:JitTrace_ELBOJitTraceGraph_ELBOJitTraceEnum_ELBOJitMeanField_ELBOHMC(jit_compile=True)NUTS(jit_compile=True)。欲了解更多信息,请参阅 examples/ 目录,其中大多数示例都包含一个 --jit 选项以在编译模式下运行。

一个简单的模型

让我们从一个简单的高斯模型和一个 自动引导 开始。

[2]:
def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    scale = pyro.sample("scale", dist.LogNormal(0., 3.))
    with pyro.plate("data", data.size(0)):
        pyro.sample("obs", dist.Normal(loc, scale), obs=data)

guide = AutoDiagonalNormal(model)

data = dist.Normal(0.5, 2.).sample((100,))

首先,像往常一样使用 SVI 对象和 Trace_ELBO 运行。

[3]:
%%time
pyro.clear_param_store()
elbo = Trace_ELBO()
svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)
for i in range(2 if smoke_test else 1000):
    svi.step(data)
CPU times: user 2.71 s, sys: 31.4 ms, total: 2.74 s
Wall time: 2.76 s

接下来,要使用 JIT 编译的推断运行,我们只需替换

- elbo = Trace_ELBO()
+ elbo = JitTrace_ELBO()

另外请注意,AutoDiagonalNormal 引导在其第一次调用时表现稍有不同(它运行模型以生成原型跟踪),我们不希望在编译时记录这种预热行为。因此,我们调用 guide(data) 一次进行初始化,然后运行编译后的 SVI,

[4]:
%%time
pyro.clear_param_store()

guide(data)  # Do any lazy initialization before compiling.

elbo = JitTrace_ELBO()
svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)
for i in range(2 if smoke_test else 1000):
    svi.step(data)
CPU times: user 1.1 s, sys: 30.4 ms, total: 1.13 s
Wall time: 1.16 s

请注意,对于这个小型模型,我们获得了超过 2 倍的速度提升。

现在让我们使用相同的模型,但我们将转而使用 MCMC 从模型的后验生成样本。我们将使用无 U 形转弯 (NUTS) 采样器。

[5]:
%%time
nuts_kernel = NUTS(model)
pyro.set_rng_seed(1)
mcmc_run = MCMC(nuts_kernel, num_samples=100).run(data)
CPU times: user 4.61 s, sys: 101 ms, total: 4.71 s
Wall time: 4.7 s

我们可以使用 jit_compile=True 参数编译 NUTS 核中的势能计算。我们还通过使用 ignore_jit_warnings=True 来消除模型中存在张量常量导致的 JIT 警告。

[6]:
%%time
nuts_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
pyro.set_rng_seed(1)
mcmc_run = MCMC(nuts_kernel, num_samples=100).run(data)
CPU times: user 2.04 s, sys: 74.1 ms, total: 2.11 s
Wall time: 2.09 s

我们注意到启用 JIT 编译后,采样吞吐量显著增加。

变化的结构

时间序列模型通常在具有不同长度的多个时间序列数据集上运行。为了适应这种变化的结构,Pyro 要求模型将所有模型输入分为张量和非张量。\(^\dagger\)

  • 非张量输入应作为 **kwargs 传递给模型和引导。这些可以决定模型结构,以便为传递的 **kwargs 的每个值编译一个模型。

  • 张量输入应作为 *args 传递。这些不得决定模型结构。但是 len(args) 可能会决定模型结构(例如在半监督模型中使用)。

为了用时间序列模型来说明这一点,我们将观察序列作为张量 arg 传入,并将序列长度作为非张量 kwarg 传入

[5]:
def model(sequence, num_sequences, length, state_dim=16):
    # This is a Gaussian HMM model.
    with pyro.plate("states", state_dim):
        trans = pyro.sample("trans", dist.Dirichlet(0.5 * torch.ones(state_dim)))
        emit_loc = pyro.sample("emit_loc", dist.Normal(0., 10.))
    emit_scale = pyro.sample("emit_scale", dist.LogNormal(0., 3.))

    # We're doing manual data subsampling, so we need to scale to actual data size.
    with poutine.scale(scale=num_sequences):
        # We'll use enumeration inference over the hidden x.
        x = 0
        for t in pyro.markov(range(length)):
            x = pyro.sample("x_{}".format(t), dist.Categorical(trans[x]),
                            infer={"enumerate": "parallel"})
            pyro.sample("y_{}".format(t), dist.Normal(emit_loc[x], emit_scale),
                        obs=sequence[t])

guide = AutoDiagonalNormal(poutine.block(model, expose=["trans", "emit_scale", "emit_loc"]))

# This is fake data of different lengths.
lengths = [24] * 50 + [48] * 20 + [72] * 5
sequences = [torch.randn(length) for length in lengths]

现在让我们像往常一样运行 SVI。

[6]:
%%time
pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)
for i in range(1 if smoke_test else 10):
    for sequence in sequences:
        svi.step(sequence,                                            # tensor args
                 num_sequences=len(sequences), length=len(sequence))  # non-tensor args
CPU times: user 52.4 s, sys: 270 ms, total: 52.7 s
Wall time: 52.8 s

我们将再次简单地换入 Jit* 实现

- elbo = TraceEnum_ELBO(max_plate_nesting=1)
+ elbo = JitTraceEnum_ELBO(max_plate_nesting=1)

请注意,我们手动指定了 max_plate_nesting 参数。通常 Pyro 可以在第一次调用时运行模型自动确定此值;但是,为了避免在第一步运行编译器时执行此额外工作,我们手动传入此值。

[7]:
%%time
pyro.clear_param_store()

# Do any lazy initialization before compiling.
guide(sequences[0], num_sequences=len(sequences), length=len(sequences[0]))

elbo = JitTraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)
for i in range(1 if smoke_test else 10):
    for sequence in sequences:
        svi.step(sequence,                                            # tensor args
                 num_sequences=len(sequences), length=len(sequence))  # non-tensor args
CPU times: user 21.9 s, sys: 201 ms, total: 22.1 s
Wall time: 22.2 s

我们再次看到了超过 2 倍的速度提升。请注意,由于有三种不同的序列长度,编译被触发了三次。

\(^\dagger\) 请注意,本节仅对 SVI 有效,而 HMC/NUTS 假定模型参数固定。