编译序列重要性采样

编译序列重要性采样 [1],或称推断编译,是一种通过学习重要性采样的提议分布来分摊推断计算成本的技术。

提议分布被学习以最小化模型与 guide 之间的 KL 散度,\(\rm{KL}\!\left( p({\bf z} | {\bf x}) \lVert q_{\phi, x}({\bf z}) \right)\)。这与变分推断不同,变分推断会最小化 \(\rm{KL}\!\left( q_{\phi, x}({\bf z}) \lVert p({\bf z} | {\bf x}) \right)\)。使用前一种损失鼓励近似提议分布比真实的后验分布更宽(覆盖质量),而变分推断通常学习更窄的近似分布(寻找模式)。重要性采样的 guide 通常期望其尾部比模型更重(参见这个 stackexchange 问题)。因此,推断编译损失通常更适合于编译重要性采样的 guide。

CSIS 的另一个优点是,与许多类型的变分推断不同,它不要求模型是可微的。这使得它可以用于对任意复杂的程序进行推断(例如,Captcha 渲染器 [1])。

本示例展示了如何使用 CSIS 来加速对已知解析解的简单问题的推断。

[1]:
import torch
import torch.nn as nn
import torch.functional as F

import pyro
import pyro.distributions as dist
import pyro.infer
import pyro.optim

import os
smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 2000

指定模型:

模型以与任何 Pyro 模型相同的方式指定,但必须使用关键字参数 observations 来输入一个字典,其中每个观察值作为键。由于推断编译涉及学习对任何观察值进行推断,因此字典中的值是什么并不重要。此处使用 0

[2]:
def model(prior_mean, observations={"x1": 0, "x2": 0}):
    x = pyro.sample("z", dist.Normal(prior_mean, torch.tensor(5**0.5)))
    y1 = pyro.sample("x1", dist.Normal(x, torch.tensor(2**0.5)), obs=observations["x1"])
    y2 = pyro.sample("x2", dist.Normal(x, torch.tensor(2**0.5)), obs=observations["x2"])
    return x

指定 guide:

guide 将被训练(又称编译)以使用观察值来为每个未条件化的 sample 语句创建提议分布。在论文 [1] 中,神经网络结构是为任何模型自动生成的。然而,在 Pyro 的实现中,用户必须指定特定任务的 guide 程序结构。与任何 Pyro guide 函数一样,它应该具有与模型相同的调用签名。它还必须遇到与模型相同的未观测到的 sample 语句。为了使 guide 程序能够训练出好的提议分布,sample 语句处的分布应该依赖于 observations 中的值。在此示例中,使用前馈神经网络将观察值映射到隐变量的提议分布。

在运行 guide 函数时调用 pyro.module,以便优化器在训练期间可以找到 guide 参数。

[3]:
class Guide(nn.Module):
    def __init__(self):
        super().__init__()
        self.neural_net = nn.Sequential(
            nn.Linear(2, 10),
            nn.ReLU(),
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 10),
            nn.ReLU(),
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 2))

    def forward(self, prior_mean, observations={"x1": 0, "x2": 0}):
        pyro.module("guide", self)
        x1 = observations["x1"]
        x2 = observations["x2"]
        v = torch.cat((x1.view(1, 1), x2.view(1, 1)), 1)
        v = self.neural_net(v)
        mean = v[0, 0]
        std = v[0, 1].exp()
        pyro.sample("z", dist.Normal(mean, std))

guide = Guide()

现在创建一个 CSIS 实例:

该对象使用模型、guide、用于训练 guide 的 PyTorch 优化器以及在执行推断时抽取的带权重重要性样本数量进行初始化。guide 将针对模型/guide 参数 prior_mean 的特定值进行优化,因此我们在整个训练和推断过程中都使用此处设置的值。

[4]:
optimiser = pyro.optim.Adam({'lr': 1e-3})
csis = pyro.infer.CSIS(model, guide, optimiser, num_inference_samples=50)
prior_mean = torch.tensor(1.)

现在我们“编译”该实例以对此模型执行推断:

传递给 csis.step 的参数在运行模型和 guide 以评估损失时传递给它们。

[5]:
for step in range(n_steps):
    csis.step(prior_mean)

现在通过重要性采样执行推断:

编译好的 guide 程序现在应该能够为 z 提出一个近似后验分布 \(p(z | x_1, x_2)\) 的分布,适用于任何 \(x_1, x_2\)。再次输入相同的 prior_mean,以及 observations 内部的观察值。

[6]:
posterior = csis.run(prior_mean,
                     observations={"x1": torch.tensor(8.),
                                   "x2": torch.tensor(9.)})
marginal = pyro.infer.EmpiricalMarginal(posterior, "z")

现在绘制结果并与重要性采样进行比较:

我们观察到 \(x_1 = 8\)\(x_2 = 9\)。使用 CSIS 抽取 50 个样本进行推断,并使用来自先验的重要性采样抽取 50 个样本。然后绘制生成的后验分布近似图,并与解析后验进行比较。

[7]:
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt

with torch.no_grad():
    # Draw samples from empirical marginal for plotting
    csis_samples = torch.stack([marginal() for _ in range(1000)])

    # Calculate empirical marginal with importance sampling
    is_posterior = pyro.infer.Importance(model, num_samples=50).run(
        prior_mean, observations={"x1": torch.tensor(8.),
                                  "x2": torch.tensor(9.)})
    is_marginal = pyro.infer.EmpiricalMarginal(is_posterior, "z")
    is_samples = torch.stack([is_marginal() for _ in range(1000)])

# Calculate true prior and posterior over z
true_posterior_z = torch.arange(-10, 10, 0.05)
true_posterior_p = dist.Normal(7.25, (5/6)**0.5).log_prob(true_posterior_z).exp()
prior_z = true_posterior_z
prior_p = dist.Normal(1., 5**0.5).log_prob(true_posterior_z).exp()

plt.rcParams['figure.figsize'] = [30, 15]
plt.rcParams.update({'font.size': 30})
fig, ax = plt.subplots()
plt.plot(prior_z, prior_p, 'k--', label='Prior')
plt.plot(true_posterior_z, true_posterior_p, color='k', label='Analytic Posterior')
plt.hist(csis_samples.numpy(), range=(-10, 10), bins=100, color='r', density=1,
         label="Inference Compilation")
plt.hist(is_samples.numpy(), range=(-10, 10), bins=100, color='b', density=1,
         label="Importance Sampling")
plt.xlim(-8, 10)
plt.ylim(0, 5)
plt.xlabel("z")
plt.ylabel("Estimated Posterior Probability Density")
plt.legend()
plt.show()
_images/csis_13_0.png

使用 \(x_1 = 8\)\(x_2 = 9\) 得到的后验分布远离先验,因此使用先验作为重要性采样的 guide 是低效的,有效样本大小非常小。通过首先学习合适的 guide 函数,CSIS 具有与真实后验更匹配的提议分布。这使得抽取的样本更能覆盖真实后验,并获得更大的有效样本大小,如上图所示。

有关推断编译的其他示例,请参阅 [1] 或 https://github.com/probprog/anglican-infcomp-examples

参考文献

[1] Inference compilation and universal probabilistic programming,     Tuan Anh Le, Atilim Gunes Baydin, and Frank Wood