将 PyTorch JIT 编译器与 Pyro 结合使用¶
本教程展示了如何在 Pyro 模型中使用 PyTorch JIT 编译器。
摘要:¶
您可以在 Pyro 模型中使用编译后的函数。
您不能在编译后的函数中使用 Pyro 原语。
如果您的模型结构是静态的,您可以使用
Jit*
版本的ELBO
算法,例如:- Trace_ELBO() + JitTrace_ELBO()
模型应将所有张量作为
*args
输入,将所有非张量作为**kwargs
输入。**kwargs
的每个不同值都会触发一次单独的编译。使用
**kwargs
指定结构中的所有变化(例如时间序列长度)。要在安全的代码块中忽略 JIT 警告,请使用
with pyro.util.ignore_jit_warnings():
。要在
HMC
或NUTS
中忽略所有 JIT 警告,请传递ignore_jit_warnings=True
。
目录¶
[1]:
import os
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro import poutine
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, JitTrace_ELBO, TraceEnum_ELBO, JitTraceEnum_ELBO, SVI
from pyro.infer.mcmc import MCMC, NUTS
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.1')
简介¶
PyTorch 1.0 包含一个 JIT 编译器 以加速模型。您可以将编译视为一种“静态模式”,而 PyTorch 通常在“即时模式”下运行。
Pyro 通过两种方式支持 JIT 编译器。首先,您可以在 Pyro 模型中使用编译后的函数(但这些函数不能包含 Pyro 原语)。其次,您可以使用 Pyro 的 JIT 推断算法来编译整个推断步骤;在静态模型中,这可以减少 Pyro 模型的 Python 开销并加速推断。
本教程的其余部分重点介绍 Pyro 的 JIT 推断算法:JitTrace_ELBO、JitTraceGraph_ELBO、JitTraceEnum_ELBO、JitMeanField_ELBO、HMC(jit_compile=True) 和 NUTS(jit_compile=True)。欲了解更多信息,请参阅 examples/ 目录,其中大多数示例都包含一个 --jit
选项以在编译模式下运行。
一个简单的模型¶
让我们从一个简单的高斯模型和一个 自动引导 开始。
[2]:
def model(data):
loc = pyro.sample("loc", dist.Normal(0., 10.))
scale = pyro.sample("scale", dist.LogNormal(0., 3.))
with pyro.plate("data", data.size(0)):
pyro.sample("obs", dist.Normal(loc, scale), obs=data)
guide = AutoDiagonalNormal(model)
data = dist.Normal(0.5, 2.).sample((100,))
首先,像往常一样使用 SVI 对象和 Trace_ELBO
运行。
[3]:
%%time
pyro.clear_param_store()
elbo = Trace_ELBO()
svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)
for i in range(2 if smoke_test else 1000):
svi.step(data)
CPU times: user 2.71 s, sys: 31.4 ms, total: 2.74 s
Wall time: 2.76 s
接下来,要使用 JIT 编译的推断运行,我们只需替换
- elbo = Trace_ELBO()
+ elbo = JitTrace_ELBO()
另外请注意,AutoDiagonalNormal
引导在其第一次调用时表现稍有不同(它运行模型以生成原型跟踪),我们不希望在编译时记录这种预热行为。因此,我们调用 guide(data)
一次进行初始化,然后运行编译后的 SVI,
[4]:
%%time
pyro.clear_param_store()
guide(data) # Do any lazy initialization before compiling.
elbo = JitTrace_ELBO()
svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)
for i in range(2 if smoke_test else 1000):
svi.step(data)
CPU times: user 1.1 s, sys: 30.4 ms, total: 1.13 s
Wall time: 1.16 s
请注意,对于这个小型模型,我们获得了超过 2 倍的速度提升。
现在让我们使用相同的模型,但我们将转而使用 MCMC 从模型的后验生成样本。我们将使用无 U 形转弯 (NUTS) 采样器。
[5]:
%%time
nuts_kernel = NUTS(model)
pyro.set_rng_seed(1)
mcmc_run = MCMC(nuts_kernel, num_samples=100).run(data)
CPU times: user 4.61 s, sys: 101 ms, total: 4.71 s
Wall time: 4.7 s
我们可以使用 jit_compile=True
参数编译 NUTS 核中的势能计算。我们还通过使用 ignore_jit_warnings=True
来消除模型中存在张量常量导致的 JIT 警告。
[6]:
%%time
nuts_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
pyro.set_rng_seed(1)
mcmc_run = MCMC(nuts_kernel, num_samples=100).run(data)
CPU times: user 2.04 s, sys: 74.1 ms, total: 2.11 s
Wall time: 2.09 s
我们注意到启用 JIT 编译后,采样吞吐量显著增加。
变化的结构¶
时间序列模型通常在具有不同长度的多个时间序列数据集上运行。为了适应这种变化的结构,Pyro 要求模型将所有模型输入分为张量和非张量。\(^\dagger\)
非张量输入应作为
**kwargs
传递给模型和引导。这些可以决定模型结构,以便为传递的**kwargs
的每个值编译一个模型。张量输入应作为
*args
传递。这些不得决定模型结构。但是len(args)
可能会决定模型结构(例如在半监督模型中使用)。
为了用时间序列模型来说明这一点,我们将观察序列作为张量 arg
传入,并将序列长度作为非张量 kwarg
传入
[5]:
def model(sequence, num_sequences, length, state_dim=16):
# This is a Gaussian HMM model.
with pyro.plate("states", state_dim):
trans = pyro.sample("trans", dist.Dirichlet(0.5 * torch.ones(state_dim)))
emit_loc = pyro.sample("emit_loc", dist.Normal(0., 10.))
emit_scale = pyro.sample("emit_scale", dist.LogNormal(0., 3.))
# We're doing manual data subsampling, so we need to scale to actual data size.
with poutine.scale(scale=num_sequences):
# We'll use enumeration inference over the hidden x.
x = 0
for t in pyro.markov(range(length)):
x = pyro.sample("x_{}".format(t), dist.Categorical(trans[x]),
infer={"enumerate": "parallel"})
pyro.sample("y_{}".format(t), dist.Normal(emit_loc[x], emit_scale),
obs=sequence[t])
guide = AutoDiagonalNormal(poutine.block(model, expose=["trans", "emit_scale", "emit_loc"]))
# This is fake data of different lengths.
lengths = [24] * 50 + [48] * 20 + [72] * 5
sequences = [torch.randn(length) for length in lengths]
现在让我们像往常一样运行 SVI。
[6]:
%%time
pyro.clear_param_store()
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)
for i in range(1 if smoke_test else 10):
for sequence in sequences:
svi.step(sequence, # tensor args
num_sequences=len(sequences), length=len(sequence)) # non-tensor args
CPU times: user 52.4 s, sys: 270 ms, total: 52.7 s
Wall time: 52.8 s
我们将再次简单地换入 Jit*
实现
- elbo = TraceEnum_ELBO(max_plate_nesting=1)
+ elbo = JitTraceEnum_ELBO(max_plate_nesting=1)
请注意,我们手动指定了 max_plate_nesting
参数。通常 Pyro 可以在第一次调用时运行模型自动确定此值;但是,为了避免在第一步运行编译器时执行此额外工作,我们手动传入此值。
[7]:
%%time
pyro.clear_param_store()
# Do any lazy initialization before compiling.
guide(sequences[0], num_sequences=len(sequences), length=len(sequences[0]))
elbo = JitTraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, Adam({'lr': 0.01}), elbo)
for i in range(1 if smoke_test else 10):
for sequence in sequences:
svi.step(sequence, # tensor args
num_sequences=len(sequences), length=len(sequence)) # non-tensor args
CPU times: user 21.9 s, sys: 201 ms, total: 22.1 s
Wall time: 22.2 s
我们再次看到了超过 2 倍的速度提升。请注意,由于有三种不同的序列长度,编译被触发了三次。
\(^\dagger\) 请注意,本节仅对 SVI 有效,而 HMC/NUTS 假定模型参数固定。