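Interactive posterior predictive checks

This notebook demonstrates how to interactively examine model priors using ipywidgets.

⚠️ This notebook is intended to be run interactively. Please run it locally or open it in Colab.

The first step in a Bayesian workflow is to create a model. The second step is to check prior samples from the model. This notebook shows how to interactively check prior samples and tune the parameters of top-level prior distributions while visualizing model outputs.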

Summary

  • Wrap your model in a plotting function.

  • Use ipywidgets.interact() to create sliders for each parameter of your prior.

  • For expensive models, use Resampler.

[1]:
!pip install -q pyro-ppl  # for colab
[2]:
import os
from ipywidgets import interact, FloatSlider
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.resampler import Resampler

assert pyro.__version__.startswith('1.9.1')
smoke_test = ('CI' in os.environ)  # for CI testing only
[3]:
def model(T: int = 1000, data=None):
    # Sample parameters from the prior.
    df = pyro.sample("df", dist.LogNormal(0, 1))
    p_scale = pyro.sample("p_scale", dist.LogNormal(0, 1))  # process noise
    m_scale = pyro.sample("m_scale", dist.LogNormal(0, 1))  # measurement noise

    # Simulate a time series.
    with pyro.plate("dt", T):
        process_noise = pyro.sample("process_noise", dist.StudentT(df, 0, p_scale))
    trend = pyro.deterministic("trend", process_noise.cumsum(-1))
    with pyro.plate("t", T):
        return pyro.sample("obs", dist.Normal(trend, m_scale), obs=data)
[4]:
def plot_trajectory(df=1.0, p_scale=1.0, m_scale=1.0):
    pyro.set_rng_seed(12345)
    data = {
        "df": torch.as_tensor(df),
        "p_scale": torch.as_tensor(p_scale),
        "m_scale": torch.as_tensor(m_scale),
    }
    trajectory = poutine.condition(model, data)()
    plt.figure(figsize=(8, 4)).patch.set_color("white")
    plt.plot(trajectory)
    plt.xlabel("time")
    plt.ylabel("obs")

Now we can examine what model trajectories look like for particular values of the top-level latent variables.

[5]:
interact(
    plot_trajectory,
    df=FloatSlider(value=1.0, min=0.01, max=10.0),
    p_scale=FloatSlider(value=0.1, min=0.01, max=1.0),
    m_scale=FloatSlider(value=1.0, min=0.01, max=10.0),
);
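
To make the three sliders concrete, here is a minimal standard-library sketch of the same generative process for fixed parameter values: Student-t process noise, a cumulative sum to form the trend, and Gaussian measurement noise. The function name `simulate_trajectory` and its `seed` argument are illustrative and not part of the notebook. The Student-t draw uses the textbook normal-over-root-chi-square construction rather than Pyro's sampler, so the numbers will not match `plot_trajectory` draw for draw.

```python
import math
import random


def simulate_trajectory(df=1.0, p_scale=1.0, m_scale=1.0, T=1000, seed=12345):
    """Simulate one trajectory for fixed top-level parameters (illustrative only)."""
    rng = random.Random(seed)
    trend = 0.0
    obs = []
    for _ in range(T):
        # Student-t draw: standard normal divided by sqrt(chi-square / df).
        z = rng.gauss(0.0, 1.0)
        chi2 = max(rng.gammavariate(df / 2.0, 2.0), 1e-300)  # guard against 0
        trend += p_scale * z / math.sqrt(chi2 / df)  # cumulative sum of process noise
        obs.append(rng.gauss(trend, m_scale))  # add measurement noise
    return obs


heavy = simulate_trajectory(df=1.0, p_scale=0.1, m_scale=1.0)
light = simulate_trajectory(df=50.0, p_scale=0.1, m_scale=1.0)
print(len(heavy), len(light))
```

Small values of `df` produce occasional very large jumps in the trend, while large values make the process noise nearly Gaussian.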

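But to tune the parameters of the prior, we'd like to look at an ensemble of trajectories, each with top-level parameters sampled from the current prior. Let's rewrite the model so that the prior parameters can be passed in.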

[6]:
def model2(T: int = 1000, data=None, df0=0, df1=1, p0=0, p1=1, m0=0, m1=1):
    # Sample parameters from the prior.
    df = pyro.sample("df", dist.LogNormal(df0, df1))
    p_scale = pyro.sample("p_scale", dist.LogNormal(p0, p1))  # process noise
    m_scale = pyro.sample("m_scale", dist.LogNormal(m0, m1))  # measurement noise

    # Simulate a time series.
    with pyro.plate("dt", T):
        process_noise = pyro.sample("process_noise", dist.StudentT(df, 0, p_scale))
    trend = pyro.deterministic("trend", process_noise.cumsum(-1))
    with pyro.plate("t", T):
        return pyro.sample("obs", dist.Normal(trend, m_scale), obs=data)
[7]:
def plot_trajectories(**kwargs):
    pyro.set_rng_seed(12345)
    with pyro.plate("trajectories", 20, dim=-2):
        trajectories = model2(**kwargs)
    plt.figure(figsize=(8, 5)).patch.set_color("white")
    plt.plot(trajectories.T)
    plt.xlabel("time")
    plt.ylabel("obs")
[8]:
interact(
    plot_trajectories,
    df0=FloatSlider(value=0.0, min=-5, max=5),
    df1=FloatSlider(value=1.0, min=0.1, max=10),
    p0=FloatSlider(value=0.0, min=-5, max=5),
    p1=FloatSlider(value=1.0, min=0.1, max=10),
    m0=FloatSlider(value=0.0, min=-5, max=5),
    m1=FloatSlider(value=1.0, min=0.1, max=10),
);

Yikes! It looks like our initial prior generated very weird trajectories, but we can use the sliders to find better priors. Try increasing df0.
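
One plausible reading of why raising `df0` helps (the notebook itself does not spell this out): a LogNormal(df0, df1) prior has median exp(df0), and a Student-t distribution has no finite variance when df is 2 or less, so with df0 = 0 most prior draws give process noise with extremely heavy tails. The sketch below computes the prior median of `df` and the prior probability that `df <= 2`. The helper name `prob_df_at_most` is made up for illustration.

```python
import math


def prob_df_at_most(threshold, df0, df1):
    """P(df <= threshold) when df ~ LogNormal(df0, df1), via the normal CDF."""
    z = (math.log(threshold) - df0) / df1
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


for df0 in (0.0, 2.0, 4.0):
    median = math.exp(df0)  # median of LogNormal(df0, df1) for any df1
    p_heavy = prob_df_at_most(2.0, df0, df1=1.0)
    print(f"df0={df0}: median df={median:.1f}, P(df <= 2)={p_heavy:.4f}")
```

With df1 = 1, roughly three quarters of prior draws at df0 = 0 have df <= 2, while at df0 = 4 almost none do.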

Resampler

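For more expensive simulations, drawing fresh samples on every slider change may be too slow for interactive use. As a computational trick, we can draw many samples once from a diffuse distribution and then reuse them for a modified distribution, provided we importance-weight or resample them. Pyro provides an importance Resampler to aid interactive visualization of expensive models.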
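
The reuse idea can be illustrated without Pyro. The sketch below is a generic importance-resampling example using only the standard library. It is not the implementation of Pyro's `Resampler`, and the names `draw_pool` and `resample` are hypothetical. It draws a pool of `df` values once from a diffuse LogNormal(0, 10), then, for any new LogNormal(mu, sigma) target, weights each pooled value by the density ratio target/proposal and resamples with replacement.

```python
import math
import random


def normal_logpdf(z, mu, sigma):
    return -0.5 * ((z - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def draw_pool(n, mu_q=0.0, sigma_q=10.0, seed=0):
    """Draw log-values once from the diffuse proposal (the expensive step in practice)."""
    rng = random.Random(seed)
    return [rng.gauss(mu_q, sigma_q) for _ in range(n)]


def resample(pool, mu, sigma, k, mu_q=0.0, sigma_q=10.0, seed=1):
    """Resample k pooled values so they approximately follow LogNormal(mu, sigma)."""
    # Work with z = log(x): the LogNormal Jacobians cancel in the density ratio.
    log_w = [normal_logpdf(z, mu, sigma) - normal_logpdf(z, mu_q, sigma_q) for z in pool]
    top = max(log_w)
    weights = [math.exp(lw - top) for lw in log_w]  # subtract max for stability
    rng = random.Random(seed)
    return [math.exp(z) for z in rng.choices(pool, weights=weights, k=k)]


pool = draw_pool(10_000)
df_samples = resample(pool, mu=4.0, sigma=1.0, k=2_000)
mean_log = sum(math.log(x) for x in df_samples) / len(df_samples)
print(f"mean of log(df) after resampling: {mean_log:.2f} (target 4.0)")
```

Changing `mu` or `sigma` only recomputes the weights, so the pool, and in a real setting any expensive simulation attached to each pooled value, is never redrawn.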

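We'll start from our original model and write a function that builds parametrized partial models with given priors. These partial models are just the top half of our model: the top-level parameters.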

[9]:
def make_partial_model(df0, df1, p0, p1, m0, m1):
    def partial_model():
        # Sample parameters from the prior.
        pyro.sample("df", dist.LogNormal(df0, df1))
        pyro.sample("p_scale", dist.LogNormal(p0, p1))  # process noise
        pyro.sample("m_scale", dist.LogNormal(m0, m1))  # measurement noise
    return partial_model

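Next we'll initialize the Resampler with a diffuse guide that covers most of the parameter space we care about. This can be expensive for real simulations, so you might want to run it overnight.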

[10]:
%%time
partial_guide = make_partial_model(0, 10, 0, 10, 0, 10)
resampler = Resampler(partial_guide, model, num_guide_samples=10000)
CPU times: user 940 ms, sys: 146 ms, total: 1.09 s
Wall time: 934 ms
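
How well a diffuse guide "covers" a target prior can be quantified with the effective sample size of the importance weights, ESS = (Σw)² / Σw². This diagnostic is not part of the notebook, and the sketch below does not query Pyro's `Resampler`. It is a standalone illustration on the log scale, with the hypothetical helper `effective_sample_size`, comparing a wide guide against one that is too narrow for a target of LogNormal(4, 1).

```python
import math
import random


def effective_sample_size(mu, sigma, mu_q, sigma_q, n=10_000, seed=0):
    """ESS of importance weights for target N(mu, sigma) under proposal N(mu_q, sigma_q)."""
    rng = random.Random(seed)
    log_w = []
    for _ in range(n):
        z = rng.gauss(mu_q, sigma_q)
        log_p = -0.5 * ((z - mu) / sigma) ** 2 - math.log(sigma)
        log_q = -0.5 * ((z - mu_q) / sigma_q) ** 2 - math.log(sigma_q)
        log_w.append(log_p - log_q)
    top = max(log_w)
    w = [math.exp(lw - top) for lw in log_w]  # rescaling does not change the ESS
    return sum(w) ** 2 / sum(x * x for x in w)


ess_wide = effective_sample_size(mu=4.0, sigma=1.0, mu_q=0.0, sigma_q=10.0)
ess_narrow = effective_sample_size(mu=4.0, sigma=1.0, mu_q=0.0, sigma_q=1.0)
print(f"wide guide ESS: {ess_wide:.0f}, narrow guide ESS: {ess_narrow:.0f}")
```

A low ESS means only a handful of pooled samples carry most of the weight, so resampled trajectories would be near-duplicates. Widening the guide or drawing more guide samples are the usual remedies.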

The Resampler.sample() method takes a modified partial model.

[11]:
def plot_resampled(df0, df1, p0, p1, m0, m1):
    partial_model = make_partial_model(df0, df1, p0, p1, m0, m1)
    samples = resampler.sample(partial_model, num_samples=20)
    trajectories = samples["obs"]
    plt.figure(figsize=(8, 5)).patch.set_color("white")
    plt.plot(trajectories.T)
    plt.xlabel("time")
    plt.ylabel("obs")
[12]:
interact(
    plot_resampled,
    df0=FloatSlider(value=0.0, min=-5, max=5),
    df1=FloatSlider(value=1.0, min=0.1, max=10),
    p0=FloatSlider(value=0.0, min=-5, max=5),
    p1=FloatSlider(value=1.0, min=0.1, max=10),
    m0=FloatSlider(value=0.0, min=-5, max=5),
    m1=FloatSlider(value=1.0, min=0.1, max=10),
);

Once we have found good prior parameters, we can hard-code them into the model.

[13]:
def model(T: int = 1000, data=None):
    df = pyro.sample("df", dist.LogNormal(4, 1))  # <-- changed 0 to 4
    p_scale = pyro.sample("p_scale", dist.LogNormal(1, 1))  # <-- changed 0 to 1
    m_scale = pyro.sample("m_scale", dist.LogNormal(0, 1))

    with pyro.plate("dt", T):
        process_noise = pyro.sample("process_noise", dist.StudentT(df, 0, p_scale))
    trend = pyro.deterministic("trend", process_noise.cumsum(-1))
    with pyro.plate("t", T):
        return pyro.sample("obs", dist.Normal(trend, m_scale), obs=data)