高斯混合模型

本教程通过混合模型的示例演示如何在 Pyro 中边缘化离散潜变量。我们将重点关注并行枚举的机制,通过在微小的 5 点数据集上训练一个简单的 1-D 高斯模型来保持模型简单。关于并行枚举的更广泛介绍,请参阅枚举教程

目录

[24]:
import os
from collections import defaultdict
import torch
import numpy as np
import scipy.stats
from torch.distributions import constraints
from matplotlib import pyplot

%matplotlib inline

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete

smoke_test = "CI" in os.environ
assert pyro.__version__.startswith('1.9.1')

概述

Pyro 的 TraceEnum_ELBO 可以自动边缘化引导和模型中的变量。在枚举引导变量时,Pyro 可以按顺序枚举(如果变量决定下游控制流,这很有用),或者通过分配新的张量维度并使用非标准评估在变量的采样点创建可能值张量来并行枚举。这些非标准值随后在模型中重播。在枚举模型中的变量时,这些变量必须并行枚举,并且不得出现在引导中。从数学上讲,引导侧枚举仅通过枚举所有值来降低随机 ELBO 的方差,而模型侧枚举通过精确边缘化变量来避免应用 Jensen 不等式。

这是我们微小的数据集。它有五个点。

[25]:
data = torch.tensor([0.0, 1.0, 10.0, 11.0, 12.0])

训练 MAP 估计器

让我们从学习给定先验和数据的模型参数 weightslocsscale 开始。我们将使用 AutoDelta 引导(以其 delta 分布命名)来学习这些参数的点估计。我们的模型将学习全局混合权重、每个混合分量的位置以及两个分量共有的共享尺度。在推断过程中,TraceEnum_ELBO 将边缘化数据点到聚类的分配。

[26]:
K = 2  # Fixed number of components.


@config_enumerate
def model(data):
    # Global variables.
    weights = pyro.sample("weights", dist.Dirichlet(0.5 * torch.ones(K)))
    scale = pyro.sample("scale", dist.LogNormal(0.0, 2.0))
    with pyro.plate("components", K):
        locs = pyro.sample("locs", dist.Normal(0.0, 10.0))

    with pyro.plate("data", len(data)):
        # Local variables.
        assignment = pyro.sample("assignment", dist.Categorical(weights))
        pyro.sample("obs", dist.Normal(locs[assignment], scale), obs=data)

为了使用这对 (model,guide) 运行推断,我们使用 Pyro 的 config_enumerate() 处理器在每次迭代中枚举所有分配。由于我们将批量 Categorical 分配包装在 pyro.plate 独立上下文中,这种枚举可以并行发生:我们只枚举 2 种可能性,而不是 2**len(data) = 32。最后,要使用并行版本的枚举,我们通过 max_plate_nesting=1 告知 Pyro 我们只使用一个 plate;这让 Pyro 知道我们正在使用最右边的维度 plate,并且 Pyro 可以使用任何其他维度进行并行化。

[27]:
optim = pyro.optim.Adam({"lr": 0.1, "betas": [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)

在推断之前,我们将初始化为合理的值。混合模型非常容易陷入局部模式。一种常见的方法是从许多随机初始化中选择最好的一个,其中聚类均值从数据的随机子样本中初始化。由于我们使用的是 AutoDelta 引导,我们可以通过定义自定义的 init_loc_fn() 来初始化。

[28]:
def init_loc_fn(site):
    if site["name"] == "weights":
        # Initialize weights to uniform.
        return torch.ones(K) / K
    if site["name"] == "scale":
        return (data.var() / 2).sqrt()
    if site["name"] == "locs":
        return data[torch.multinomial(torch.ones(len(data)) / len(data), K)]
    raise ValueError(site["name"])


def initialize(seed):
    global global_guide, svi
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    global_guide = AutoDelta(
        poutine.block(model, expose=["weights", "locs", "scale"]),
        init_loc_fn=init_loc_fn,
    )
    svi = SVI(model, global_guide, optim, loss=elbo)
    return svi.loss(model, global_guide, data)


# Choose the best among 100 random initializations.
loss, seed = min((initialize(seed), seed) for seed in range(100))
initialize(seed)
print(f"seed = {seed}, initial_loss = {loss}")
seed = 13, initial_loss = 25.665584564208984

在训练过程中,我们将收集损失和梯度范数以监控收敛。我们可以使用 PyTorch 的 .register_hook() 方法来实现这一点。

[29]:
# Register hooks to monitor gradient norms.
gradient_norms = defaultdict(list)
for name, value in pyro.get_param_store().named_parameters():
    value.register_hook(
        lambda g, name=name: gradient_norms[name].append(g.norm().item())
    )

losses = []
for i in range(200 if not smoke_test else 2):
    loss = svi.step(data)
    losses.append(loss)
    print("." if i % 100 else "\n", end="")
...................................................................................................
...................................................................................................
[30]:
pyplot.figure(figsize=(10, 3), dpi=100).set_facecolor("white")
pyplot.plot(losses)
pyplot.xlabel("iters")
pyplot.ylabel("loss")
pyplot.yscale("log")
pyplot.title("Convergence of SVI");
_images/gmm_12_0.png
[31]:
pyplot.figure(figsize=(10, 4), dpi=100).set_facecolor("white")
for name, grad_norms in gradient_norms.items():
    pyplot.plot(grad_norms, label=name)
pyplot.xlabel("iters")
pyplot.ylabel("gradient norm")
pyplot.yscale("log")
pyplot.legend(loc="best")
pyplot.title("Gradient norms during SVI");
_images/gmm_13_0.png

这是学习到的参数

[32]:
map_estimates = global_guide(data)
weights = map_estimates["weights"]
locs = map_estimates["locs"]
scale = map_estimates["scale"]
print(f"weights = {weights.data.numpy()}")
print(f"locs = {locs.data.numpy()}")
print(f"scale = {scale.data.numpy()}")
weights = [0.625 0.375]
locs = [10.984464    0.49901518]
scale = 0.6514337062835693

模型的 weights 符合预期,大约 2/5 的数据在第一个分量中,3/5 在第二个分量中。接下来让我们可视化混合模型。

[33]:
X = np.arange(-3, 15, 0.1)
Y1 = weights[0].item() * scipy.stats.norm.pdf((X - locs[0].item()) / scale.item())
Y2 = weights[1].item() * scipy.stats.norm.pdf((X - locs[1].item()) / scale.item())

pyplot.figure(figsize=(10, 4), dpi=100).set_facecolor("white")
pyplot.plot(X, Y1, "r-")
pyplot.plot(X, Y2, "b-")
pyplot.plot(X, Y1 + Y2, "k--")
pyplot.plot(data.data.numpy(), np.zeros(len(data)), "k*")
pyplot.title("Density of two-component mixture model")
pyplot.ylabel("probability density");
_images/gmm_17_0.png

最后请注意,混合模型的优化是非凸的,并且经常会陷入局部最优。例如,在本教程中,我们观察到如果 scale 初始化得太大,混合模型会陷入所有数据都在一个聚类中的假设。

服务模型:预测归属

既然我们已经训练了一个混合模型,我们可能想将模型用作分类器。在训练过程中,我们边缘化了模型中的分配变量。虽然这提供了快速收敛,但它阻止我们从引导中读取聚类分配。我们将讨论两种将模型视为分类器的方法:第一种是使用 infer_discrete(快得多),第二种是通过在 SVI 中使用枚举训练辅助引导(较慢但更通用)。

使用离散推断预测归属

预测归属的最快方法是使用 infer_discrete 处理器,以及 tracereplay。让我们从 MAP 分类器开始,将 infer_discrete 的温度参数设置为零。要深入了解 tracereplayinfer_discrete 等效应处理器,请参阅效应处理器教程

[34]:
guide_trace = poutine.trace(global_guide).get_trace(data)  # record the globals
trained_model = poutine.replay(model, trace=guide_trace)  # replay the globals


def classifier(data, temperature=0):
    inferred_model = infer_discrete(
        trained_model, temperature=temperature, first_available_dim=-2
    )  # avoid conflict with data plate
    trace = poutine.trace(inferred_model).get_trace(data)
    return trace.nodes["assignment"]["value"]


print(classifier(data))
tensor([1, 1, 0, 0, 0])

实际上,我们可以在新数据上运行此分类器

[35]:
new_data = torch.arange(-3, 15, 0.1)
assignment = classifier(new_data)
pyplot.figure(figsize=(8, 2), dpi=100).set_facecolor("white")
pyplot.plot(new_data.numpy(), assignment.numpy())
pyplot.title("MAP assignment")
pyplot.xlabel("data value")
pyplot.ylabel("class assignment");
_images/gmm_21_0.png

为了生成随机后验分配而不是 MAP 分配,我们可以设置 temperature=1

[36]:
print(classifier(data, temperature=1))
tensor([1, 1, 0, 0, 0])

由于类别分离得非常好,我们放大到类别边界附近,大约在 5.75。

[37]:
new_data = torch.arange(5.5, 6.0, 0.005)
assignment = classifier(new_data, temperature=1)
pyplot.figure(figsize=(8, 2), dpi=100).set_facecolor("white")
pyplot.plot(new_data.numpy(), assignment.numpy(), "x", color="C0")
pyplot.title("Random posterior assignment")
pyplot.xlabel("data value")
pyplot.ylabel("class assignment");
_images/gmm_25_0.png

通过在引导中枚举来预测归属

预测类别归属的第二种方法是在引导中枚举。这对于服务分类器模型效果不好,因为我们需要对每个新的输入数据批次运行随机优化,但它更通用,可以嵌入到更大的变分模型中。

为了从引导中读取聚类分配,我们将定义一个新的 full_guide,它拟合全局参数(如上所述)和局部参数(之前被边缘化)。由于我们已经学习了全局变量的良好值,我们将使用 poutine.block 阻止 SVI 更新这些变量。

[38]:
@config_enumerate
def full_guide(data):
    # Global variables.
    with poutine.block(
        hide_types=["param"]
    ):  # Keep our learned values of global parameters.
        global_guide(data)

    # Local variables.
    with pyro.plate("data", len(data)):
        assignment_probs = pyro.param(
            "assignment_probs",
            torch.ones(len(data), K) / K,
            constraint=constraints.simplex,
        )
        pyro.sample("assignment", dist.Categorical(assignment_probs))
[39]:
optim = pyro.optim.Adam({"lr": 0.2, "betas": [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, full_guide, optim, loss=elbo)

# Register hooks to monitor gradient norms.
gradient_norms = defaultdict(list)
svi.loss(model, full_guide, data)  # Initializes param store.
for name, value in pyro.get_param_store().named_parameters():
    value.register_hook(
        lambda g, name=name: gradient_norms[name].append(g.norm().item())
    )

losses = []
for i in range(200 if not smoke_test else 2):
    loss = svi.step(data)
    losses.append(loss)
    print("." if i % 100 else "\n", end="")
...................................................................................................
...................................................................................................
[40]:
pyplot.figure(figsize=(10, 3), dpi=100).set_facecolor("white")
pyplot.plot(losses)
pyplot.xlabel("iters")
pyplot.ylabel("loss")
pyplot.yscale("log")
pyplot.title("Convergence of SVI");
_images/gmm_29_0.png
[41]:
pyplot.figure(figsize=(10, 4), dpi=100).set_facecolor("white")
for name, grad_norms in gradient_norms.items():
    pyplot.plot(grad_norms, label=name)
pyplot.xlabel("iters")
pyplot.ylabel("gradient norm")
pyplot.yscale("log")
pyplot.legend(loc="best")
pyplot.title("Gradient norms during SVI");
_images/gmm_30_0.png

现在我们可以检查引导的局部变量 assignment_probs

[42]:
assignment_probs = pyro.param("assignment_probs")
pyplot.figure(figsize=(8, 3), dpi=100).set_facecolor("white")
pyplot.plot(
    data.data.numpy(),
    assignment_probs.data.numpy()[:, 0],
    "ro",
    label=f"component with mean {locs[0]:0.2g}",
)
pyplot.plot(
    data.data.numpy(),
    assignment_probs.data.numpy()[:, 1],
    "bo",
    label=f"component with mean {locs[1]:0.2g}",
)
pyplot.title("Mixture assignment probabilities")
pyplot.xlabel("data value")
pyplot.ylabel("assignment probability")
pyplot.legend(loc="center");
_images/gmm_32_0.png

MCMC

接下来我们将使用折叠 NUTS 探索分量参数上的完整后验,即我们将使用 NUTS 并边缘化所有离散潜变量。

[43]:
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc import NUTS

pyro.set_rng_seed(2)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=250, warmup_steps=50)
mcmc.run(data)
posterior_samples = mcmc.get_samples()
Sample: 100%|██████████████████████████████████████████| 300/300 [00:12, 23.57it/s, step size=4.38e-01, acc. prob=0.951]
[44]:
X, Y = posterior_samples["locs"].T
[45]:
pyplot.figure(figsize=(8, 8), dpi=100).set_facecolor("white")
h, xs, ys, image = pyplot.hist2d(X.numpy(), Y.numpy(), bins=[20, 20])
pyplot.contour(
    np.log(h + 3).T,
    extent=[xs.min(), xs.max(), ys.min(), ys.max()],
    colors="white",
    alpha=0.8,
)
pyplot.title("Posterior density as estimated by collapsed NUTS")
pyplot.xlabel("loc of component 0")
pyplot.ylabel("loc of component 1")
pyplot.tight_layout()
_images/gmm_36_0.png

请注意,由于混合分量的不可辨识性,似然曲面有两个可能性相等的模式,分别接近 (11,0.5)(0.5,11)。NUTS 在两个模式之间切换时存在困难。

[46]:
pyplot.figure(figsize=(8, 3), dpi=100).set_facecolor("white")
pyplot.plot(X.numpy(), color="red")
pyplot.plot(Y.numpy(), color="blue")
pyplot.xlabel("NUTS step")
pyplot.ylabel("loc")
pyplot.title("Trace plot of loc parameter during NUTS inference")
pyplot.tight_layout()
_images/gmm_38_0.png