定制 SVI 目标和训练循环

Pyro 支持各种基于优化的贝叶斯推断方法,其中 Trace_ELBO 作为 SVI(随机变分推断)的基础实现。关于各种 SVI 实现的更多信息,请参阅文档;关于 SVI 的背景知识,请参阅 SVI 教程IIIIII

在本教程中,我们将展示高级用户如何修改和/或增强 Pyro 提供的变分目标(或称:损失函数)和训练步骤实现,以支持特殊用例。

  1. SVI 基本用法

    1. 一种更底层的方式

  2. 示例:自定义正则化器

  3. 示例:缩放损失

  4. 示例:Beta VAE

  5. 示例:混合优化器

  6. 示例:自定义 ELBO

  7. 示例:KL 退火

SVI 基本用法

我们首先回顾 Pyro 中 SVI 对象的基本使用模式。我们假设用户已经定义了一个 model 和一个 guide。然后用户创建一个优化器和一个 SVI 对象

optimizer = pyro.optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)})
svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())

然后可以通过调用 svi.step(...) 来执行梯度步骤。step() 的参数会传递给 modelguide

一种更底层的方式

上述模式的好处在于它允许 Pyro 帮我们处理各种细节,例如

  • 当遇到新参数时,pyro.optim.Adam 会动态创建一个新的 torch.optim.Adam 优化器

  • SVI.step() 会在梯度步骤之间将梯度清零

如果我们想要更多控制,可以直接操作各种 ELBO 类的可微分损失方法。例如,这个优化循环

svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())
for i in range(n_iter):
    loss = svi.step(X_train, y_train)

等价于这种底层方式

loss_fn = lambda model, guide: pyro.infer.Trace_ELBO().differentiable_loss(model, guide, X_train, y_train)
with pyro.poutine.trace(param_only=True) as param_capture:
    loss = loss_fn(model, guide)
params = set(site["value"].unconstrained()
                for site in param_capture.trace.nodes.values())
optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.90, 0.999))
for i in range(n_iter):
    # compute loss
    loss = loss_fn(model, guide)
    loss.backward()
    # take a step and zero the parameter gradients
    optimizer.step()
    optimizer.zero_grad()

示例:自定义正则化器

假设我们想给 SVI 损失函数添加一个自定义正则化项。使用上述用法模式,这很容易做到。首先我们定义正则化器

def my_custom_L2_regularizer(my_parameters):
    reg_loss = 0.0
    for param in my_parameters:
        reg_loss = reg_loss + param.pow(2.0).sum()
    return reg_loss

然后我们唯一需要做的改变是

- loss = loss_fn(model, guide)
+ loss = loss_fn(model, guide) + my_custom_L2_regularizer(my_parameters)

示例:梯度裁剪

对于某些模型,损失梯度在训练过程中可能会爆炸,导致溢出和 NaN 值。一种防止这种情况的方法是梯度裁剪。pyro.optim 中的各种优化器接受一个可选的 clip_args 字典,它允许将梯度范数或梯度值裁剪到给定限制内。

改变上面基本示例的方式是

- optimizer = pyro.optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)})
+ optimizer = pyro.optim.Adam({"lr": 0.001, "betas": (0.90, 0.999)}, {"clip_norm": 10.0})

还可以通过修改上述底层模式手动实现更多变体的梯度裁剪。

示例:缩放损失

根据优化算法的不同,损失函数的尺度可能重要也可能不重要。假设我们想在求导之前,按数据点数量缩放损失函数。这很容易做到

- loss = loss_fn(model, guide)
+ loss = loss_fn(model, guide) / N_data

注意,对于 SVI,损失函数中的每一项都是来自模型或 guide 的对数概率,可以使用 poutine.scale 实现同样的效果。例如,我们可以使用 poutine.scale 装饰器来缩放模型和 guide

@poutine.scale(scale=1.0/N_data)
def model(...):
    pass

@poutine.scale(scale=1.0/N_data)
def guide(...):
    pass

示例:Beta VAE

我们还可以使用 poutine.scale 构建非标准的 ELBO 变分目标,例如,其中 KL 散度相对于期望对数似然有不同的缩放。特别地,对于 Beta VAE,KL 散度通过因子 beta 进行缩放

def model(data, beta=0.5):
    z_loc, z_scale = ...
    with pyro.poutine.scale(scale=beta)
        z = pyro.sample("z", dist.Normal(z_loc, z_scale))
    pyro.sample("obs", dist.Bernoulli(...), obs=data)

def guide(data, beta=0.5):
    with pyro.poutine.scale(scale=beta)
        z_loc, z_scale = ...
        z = pyro.sample("z", dist.Normal(z_loc, z_scale))

选择这样的模型和 guide 后,与潜变量 z 对应的对数密度会进入变分目标的构建过程,通过

svi = pyro.infer.SVI(model, guide, optimizer, loss=pyro.infer.Trace_ELBO())

会被因子 beta 缩放,从而导致 KL 散度也按 beta 缩放。

示例:混合优化器

pyro.optim 中的各种优化器允许用户按参数指定优化设置(例如学习率)。但是如果想对不同参数使用不同的优化算法怎么办?可以使用 Pyro 的 MultiOptimizer(见下文),但如果直接操作 differentiable_loss 也可以达到同样的效果

adam = torch.optim.Adam(adam_parameters, {"lr": 0.001, "betas": (0.90, 0.999)})
sgd = torch.optim.SGD(sgd_parameters, {"lr": 0.0001})
loss_fn = pyro.infer.Trace_ELBO().differentiable_loss
# compute loss
loss = loss_fn(model, guide)
loss.backward()
# take a step and zero the parameter gradients
adam.step()
sgd.step()
adam.zero_grad()
sgd.zero_grad()

为了完整起见,我们还展示了如何使用 MultiOptimizer 完成同样的事情,它允许我们组合多个 Pyro 优化器。注意,由于 MultiOptimizer 在内部使用 torch.autograd.grad(而不是 torch.Tensor.backward()),它的接口略有不同;特别是 step() 方法也接受参数作为输入。

def model():
    pyro.param('a', ...)
    pyro.param('b', ...)
    ...

adam = pyro.optim.Adam({'lr': 0.1})
sgd = pyro.optim.SGD({'lr': 0.01})
optim = MixedMultiOptimizer([(['a'], adam), (['b'], sgd)])
with pyro.poutine.trace(param_only=True) as param_capture:
    loss = elbo.differentiable_loss(model, guide)
params = {'a': pyro.param('a'), 'b': pyro.param('b')}
optim.step(loss, params)

示例:自定义 ELBO

在前三个示例中,我们绕过了创建 SVI 对象,而是直接操作由 ELBO 实现提供的可微分损失函数。我们还可以做另一件事,那就是创建自定义的 ELBO 实现并将它们传入 SVI 机制。例如,Trace_ELBO 损失函数的一个简化版本可能如下所示

# note that simple_elbo takes a model, a guide, and their respective arguments as inputs
def simple_elbo(model, guide, *args, **kwargs):
    # run the guide and trace its execution
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
    # run the model and replay it against the samples from the guide
    model_trace = poutine.trace(
        poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)
    # construct the elbo loss function
    return -1*(model_trace.log_prob_sum() - guide_trace.log_prob_sum())

svi = SVI(model, guide, optim, loss=simple_elbo)

注意,这基本上就是 “mini-pyro”elbo 实现的样子。

示例:KL 退火

深度马尔可夫模型教程中,ELBO 变分目标在训练期间被修改。特别是,潜随机变量之间的各种 KL 散度项相对于观测数据的对数概率被向下缩放(即退火)。在教程中,这是通过使用 poutine.scale 完成的。我们可以通过定义自定义损失函数来完成同样的事情。后一种选择不是非常优雅的方式,但我们仍然将其包含在内,以展示我们所拥有的灵活性。

def simple_elbo_kl_annealing(model, guide, *args, **kwargs):
    # get the annealing factor and latents to anneal from the keyword
    # arguments passed to the model and guide
    annealing_factor = kwargs.pop('annealing_factor', 1.0)
    latents_to_anneal = kwargs.pop('latents_to_anneal', [])
    # run the guide and replay the model against the guide
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
    model_trace = poutine.trace(
        poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)

    elbo = 0.0
    # loop through all the sample sites in the model and guide trace and
    # construct the loss; note that we scale all the log probabilities of
    # samples sites in `latents_to_anneal` by the factor `annealing_factor`
    for site in model_trace.values():
        if site["type"] == "sample":
            factor = annealing_factor if site["name"] in latents_to_anneal else 1.0
            elbo = elbo + factor * site["fn"].log_prob(site["value"]).sum()
    for site in guide_trace.values():
        if site["type"] == "sample":
            factor = annealing_factor if site["name"] in latents_to_anneal else 1.0
            elbo = elbo - factor * site["fn"].log_prob(site["value"]).sum()
    return -elbo

svi = SVI(model, guide, optim, loss=simple_elbo_kl_annealing)
svi.step(other_args, annealing_factor=0.2, latents_to_anneal=["my_latent"])