交互式后验预测检查¶
本 Notebook 演示了如何使用 ipywidgets 交互式地检查模型先验。
⚠️ 本 Notebook 旨在交互式运行。请在本地运行或在 Colab 中打开。
贝叶斯工作流程的第一步是创建模型。第二步是检查模型的先验样本。本 Notebook 展示了如何在可视化模型输出的同时,交互式地检查先验样本并调整顶层先验分布的参数。
摘要¶
将模型封装在一个绘图函数中。
使用 ipywidgets.interact() 为先验的每个参数创建滑块。
对于计算开销大的模型,请使用 Resampler。
[1]:
!pip install -q pyro-ppl # for colab
[2]:
import os
from ipywidgets import interact, FloatSlider
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from pyro.infer.resampler import Resampler
assert pyro.__version__.startswith('1.9.1')
smoke_test = ('CI' in os.environ) # for CI testing only
[3]:
def model(T: int = 1000, data=None):
# Sample parameters from the prior.
df = pyro.sample("df", dist.LogNormal(0, 1))
p_scale = pyro.sample("p_scale", dist.LogNormal(0, 1)) # process noise
m_scale = pyro.sample("m_scale", dist.LogNormal(0, 1)) # measurement noise
# Simulate a time series.
with pyro.plate("dt", T):
process_noise = pyro.sample("process_noise", dist.StudentT(df, 0, p_scale))
trend = pyro.deterministic("trend", process_noise.cumsum(-1))
with pyro.plate("t", T):
return pyro.sample("obs", dist.Normal(trend, m_scale), obs=data)
[4]:
def plot_trajectory(df=1.0, p_scale=1.0, m_scale=1.0):
pyro.set_rng_seed(12345)
data = {
"df": torch.as_tensor(df),
"p_scale": torch.as_tensor(p_scale),
"m_scale": torch.as_tensor(m_scale),
}
trajectory = poutine.condition(model, data)()
plt.figure(figsize=(8, 4)).patch.set_color("white")
plt.plot(trajectory)
plt.xlabel("time")
plt.ylabel("obs")
现在我们可以查看顶层隐变量特定取值下的模型轨迹是什么样的。
[5]:
interact(
plot_trajectory,
df=FloatSlider(value=1.0, min=0.01, max=10.0),
p_scale=FloatSlider(value=0.1, min=0.01, max=1.0),
m_scale=FloatSlider(value=1.0, min=0.01, max=10.0),
);
但是为了调整先验的参数,我们希望查看一组轨迹,其中每个轨迹的顶层参数都是从当前先验中采样的。让我们重写模型,以便可以输入先验参数。
[6]:
def model2(T: int = 1000, data=None, df0=0, df1=1, p0=0, p1=1, m0=0, m1=1):
# Sample parameters from the prior.
df = pyro.sample("df", dist.LogNormal(df0, df1))
p_scale = pyro.sample("p_scale", dist.LogNormal(p0, p1)) # process noise
m_scale = pyro.sample("m_scale", dist.LogNormal(m0, m1)) # measurement noise
# Simulate a time series.
with pyro.plate("dt", T):
process_noise = pyro.sample("process_noise", dist.StudentT(df, 0, p_scale))
trend = pyro.deterministic("trend", process_noise.cumsum(-1))
with pyro.plate("t", T):
return pyro.sample("obs", dist.Normal(trend, m_scale), obs=data)
[7]:
def plot_trajectories(**kwargs):
pyro.set_rng_seed(12345)
with pyro.plate("trajectories", 20, dim=-2):
trajectories = model2(**kwargs)
plt.figure(figsize=(8, 5)).patch.set_color("white")
plt.plot(trajectories.T)
plt.xlabel("time")
plt.ylabel("obs")
[8]:
interact(
plot_trajectories,
df0=FloatSlider(value=0.0, min=-5, max=5),
df1=FloatSlider(value=1.0, min=0.1, max=10),
p0=FloatSlider(value=0.0, min=-5, max=5),
p1=FloatSlider(value=1.0, min=0.1, max=10),
m0=FloatSlider(value=0.0, min=-5, max=5),
m1=FloatSlider(value=1.0, min=0.1, max=10),
);
呀!看起来我们最初的先验生成了非常奇怪的轨迹,但我们可以通过滑动来找到更好的先验。尝试增加 df0
。
重采样器¶
对于计算开销更大的模拟,每次更改时采样可能太慢而无法进行交互式生成样本。作为一个计算技巧,我们可以从一个扩散分布中一次抽取许多样本,然后从修改后的分布中对它们进行重采样——前提是我们进行重要性采样或重采样。Pyro 提供了一个重要性 Resampler,用于帮助交互式可视化开销大的模型。
我们将从原始模型开始,创建一个方法来构建具有给定先验的参数化局部模型。这些局部模型就是我们模型的上半部分,即顶层参数。
[9]:
def make_partial_model(df0, df1, p0, p1, m0, m1):
def partial_model():
# Sample parameters from the prior.
pyro.sample("df", dist.LogNormal(df0, df1))
pyro.sample("p_scale", dist.LogNormal(p0, p1)) # process noise
pyro.sample("m_scale", dist.LogNormal(m0, m1)) # measurement noise
return partial_model
接下来,我们将使用一个覆盖大部分所需参数空间的扩散引导函数来初始化 Resampler
。这在实际模拟中可能开销很大,所以你可能希望让它运行通宵。
[10]:
%%time
partial_guide = make_partial_model(0, 10, 0, 10, 0, 10)
resampler = Resampler(partial_guide, model, num_guide_samples=10000)
CPU times: user 940 ms, sys: 146 ms, total: 1.09 s
Wall time: 934 ms
Resampler.sample()
方法接受一个修改过的局部模型。
[11]:
def plot_resampled(df0, df1, p0, p1, m0, m1):
partial_model = make_partial_model(df0, df1, p0, p1, m0, m1)
samples = resampler.sample(partial_model, num_samples=20)
trajectories = samples["obs"]
plt.figure(figsize=(8, 5)).patch.set_color("white")
plt.plot(trajectories.T)
plt.xlabel("time")
plt.ylabel("obs")
[12]:
interact(
plot_resampled,
df0=FloatSlider(value=0.0, min=-5, max=5),
df1=FloatSlider(value=1.0, min=0.1, max=10),
p0=FloatSlider(value=0.0, min=-5, max=5),
p1=FloatSlider(value=1.0, min=0.1, max=10),
m0=FloatSlider(value=0.0, min=-5, max=5),
m1=FloatSlider(value=1.0, min=0.1, max=10),
);
确定了好的先验参数后,我们可以将它们硬编码到模型中。
[13]:
def model(T: int = 1000, data=None):
df = pyro.sample("df", dist.LogNormal(4, 1)) # <-- changed 0 to 4
p_scale = pyro.sample("p_scale", dist.LogNormal(1, 1)) # <-- changed 0 to 1
m_scale = pyro.sample("m_scale", dist.LogNormal(0, 1))
with pyro.plate("dt", T):
process_noise = pyro.sample("process_noise", dist.StudentT(df, 0, p_scale))
trend = pyro.deterministic("trend", process_noise.cumsum(-1))
with pyro.plate("t", T):
return pyro.sample("obs", dist.Normal(trend, m_scale), obs=data)
[ ]: