SVI 第三部分:ELBO 梯度估计器

设置

我们定义了一个 Pyro 模型,其观测值 \({\bf x}\) 和隐变量 \({\bf z}\) 的形式为 \(p_{\theta}({\bf x}, {\bf z}) = p_{\theta}({\bf x}|{\bf z}) p_{\theta}({\bf z})\)。我们还定义了一个 Pyro 引导(即变分分布),形式为 \(q_{\phi}({\bf z})\)。这里 \({\theta}\)\(\phi\) 分别是模型和引导的变分参数。(特别地,这些 不是 需要贝叶斯处理的随机变量)。

我们希望通过最大化 ELBO(证据下界)来最大化对数证据 \(\log p_{\theta}({\bf x})\),ELBO 由以下公式给出

\[{\rm ELBO} \equiv \mathbb{E}_{q_{\phi}({\bf z})} \left [ \log p_{\theta}({\bf x}, {\bf z}) - \log q_{\phi}({\bf z}) \right]\]

为此,我们将在参数空间 \(\{ \theta, \phi \}\) 中对 ELBO 执行(随机)梯度步长(有关此方法的早期工作,请参阅参考文献 [1,2])。因此,我们需要能够计算以下项的无偏估计

\[\nabla_{\theta,\phi} {\rm ELBO} = \nabla_{\theta,\phi}\mathbb{E}_{q_{\phi}({\bf z})} \left [ \log p_{\theta}({\bf x}, {\bf z}) - \log q_{\phi}({\bf z}) \right]\]

对于一般随机函数 model()guide(),我们如何做到这一点?为了简化符号,我们将讨论稍微概括一下,并探讨如何计算任意成本函数 \(f({\bf z})\) 的期望梯度。我们也不再区分 \(\theta\)\(\phi\)。因此,我们希望计算

\[\nabla_{\phi}\mathbb{E}_{q_{\phi}({\bf z})} \left [ f_{\phi}({\bf z}) \right]\]

让我们从最简单的情况开始。

简单情况:可重参数化的随机变量

假设我们可以进行重参数化,使得

\[\mathbb{E}_{q_{\phi}({\bf z})} \left [f_{\phi}({\bf z}) \right] =\mathbb{E}_{q({\bf \epsilon})} \left [f_{\phi}(g_{\phi}({\bf \epsilon})) \right]\]

关键在于,我们将所有对 \(\phi\) 的依赖都移到了期望内部;\(q({\bf \epsilon})\) 是一个固定分布,不依赖于 \(\phi\)。这种重参数化可以应用于许多分布(例如正态分布);有关讨论,请参阅参考文献 [3]。在这种情况下,我们可以将梯度直接通过期望传递,得到

\[\nabla_{\phi}\mathbb{E}_{q({\bf \epsilon})} \left [f_{\phi}(g_{\phi}({\bf \epsilon})) \right]= \mathbb{E}_{q({\bf \epsilon})} \left [\nabla_{\phi}f_{\phi}(g_{\phi}({\bf \epsilon})) \right]\]

假设 \(f(\cdot)\)\(g(\cdot)\) 足够平滑,我们现在可以通过对该期望进行蒙特卡罗估计来获得感兴趣的梯度的无偏估计。

棘手情况:不可重参数化的随机变量

如果我们无法进行上述重参数化怎么办?不幸的是,许多感兴趣的分布(例如所有离散分布)都属于这种情况。在这种情况下,我们的估计器形式会稍微复杂一些。

我们首先将感兴趣的梯度展开为

\[\nabla_{\phi}\mathbb{E}_{q_{\phi}({\bf z})} \left [ f_{\phi}({\bf z}) \right]= \nabla_{\phi} \int d{\bf z} \; q_{\phi}({\bf z}) f_{\phi}({\bf z})\]

并使用链式法则将其写为

\[\int d{\bf z} \; \left \{ (\nabla_{\phi} q_{\phi}({\bf z})) f_{\phi}({\bf z}) + q_{\phi}({\bf z})(\nabla_{\phi} f_{\phi}({\bf z}))\right \}\]

在这一点上,我们遇到了一个问题。我们知道如何从 \(q(\cdot)\) 中生成样本——我们只需正向运行引导——但 \(\nabla_{\phi} q_{\phi}({\bf z})\) 甚至不是一个有效的概率密度。因此,我们需要调整这个公式,使其成为关于 \(q(\cdot)\) 的期望形式。这可以使用以下恒等式轻松完成

\[ \nabla_{\phi} q_{\phi}({\bf z}) = q_{\phi}({\bf z})\nabla_{\phi} \log q_{\phi}({\bf z})\]

这允许我们将感兴趣的梯度重写为

\[\mathbb{E}_{q_{\phi}({\bf z})} \left [ (\nabla_{\phi} \log q_{\phi}({\bf z})) f_{\phi}({\bf z}) + \nabla_{\phi} f_{\phi}({\bf z})\right]\]

这种形式的梯度估计器——也被称为 REINFORCE 估计器、评分函数估计器或似然比估计器——适用于简单的蒙特卡罗估计。

请注意,打包此结果的一种方法(方便实现)是引入一个替代目标函数

\[{\rm surrogate \;objective} \equiv \log q_{\phi}({\bf z}) \overline{f_{\phi}({\bf z})} + f_{\phi}({\bf z})\]

这里的横线表示该项保持不变(即不对 \(\phi\) 求导)。为了获得(单样本)蒙特卡罗梯度估计,我们对隐随机变量进行采样,计算替代目标,然后求导。结果是 \(\nabla_{\phi}\mathbb{E}_{q_{\phi}({\bf z})} \left [ f_{\phi}({\bf z}) \right]\) 的无偏估计。用公式表示

\[\nabla_{\phi} {\rm ELBO} = \mathbb{E}_{q_{\phi}({\bf z})} \left [ \nabla_{\phi} ({\rm surrogate \; objective}) \right]\]

方差或我为什么希望我在做 MLE 深度学习

我们现在有了一个用于计算成本函数期望的无偏梯度估计器的通用方法。不幸的是,在 \(q(\cdot)\) 包含不可重参数化随机变量的更一般情况下,这个估计器往往具有高方差。事实上,在许多感兴趣的情况下,方差非常高,导致估计器实际上无法使用。因此,我们需要降低方差的策略(有关讨论,请参阅参考文献 [4])。我们将采用两种策略。第一种策略利用了成本函数 \(f(\cdot)\) 的特殊结构。第二种策略有效地引入了一种通过使用先前对 \(\mathbb{E}_{q_{\phi}({\bf z})} [ f_{\phi}({\bf z})]\) 的估计信息来降低方差的方法。因此,它有点类似于随机梯度下降中的动量。

通过依赖结构降低方差

在上面的讨论中,我们坚持使用一般的成本函数 \(f_{\phi}({\bf z})\)。我们可以沿着这条思路继续下去(我们即将讨论的方法适用于一般情况),但为了具体起见,让我们回到变分推断。在随机变分推断中,我们感兴趣的是一种特定形式的成本函数

\[\log p_{\theta}({\bf x} | {\rm Pa}_p ({\bf x})) + \sum_i \log p_{\theta}({\bf z}_i | {\rm Pa}_p ({\bf z}_i)) - \sum_i \log q_{\phi}({\bf z}_i | {\rm Pa}_q ({\bf z}_i))\]

这里我们将对数比率 \(\log p_{\theta}({\bf x}, {\bf z})/q_{\phi}({\bf z})\) 分解为一个观测值对数似然项和对不同隐随机变量 \(\{{\bf z}_i \}\) 的求和。我们还引入了记号 \({\rm Pa}_p (\cdot)\)\({\rm Pa}_q (\cdot)\) 分别表示给定随机变量在模型和引导中的父节点。(读者可能会担心在一般随机函数的情况下,适当的依赖概念是什么;这里我们仅指单个执行轨迹内的常规依赖)。关键在于成本函数中的不同项对随机变量 \(\{ {\bf z}_i \}\) 具有不同的依赖关系,这是我们可以利用的。

长话短说,对于任何不可重参数化的隐随机变量 \({\bf z}_i\),替代目标将包含一个项

\[\log q_{\phi}({\bf z}_i) \overline{f_{\phi}({\bf z})}\]

事实证明,我们可以移除 \(\overline{f_{\phi}({\bf z})}\) 中的一些项,并且仍然获得无偏梯度估计器;此外,这样做通常会降低方差。具体来说(详情请参阅参考文献 [4]),我们可以移除 \(\overline{f_{\phi}({\bf z})}\) 中任何不在隐变量 \({\bf z}_i\) 下游的项(下游是指相对于引导的依赖结构)。请注意,这种一般技巧——其中某些随机变量被解析处理以降低方差——通常被称为 Rao-Blackwell 化。

在 Pyro 中,所有这些逻辑都由 SVI 类自动处理。特别是,只要我们使用 TraceGraph_ELBO 损失,Pyro 就会跟踪模型和引导执行轨迹中的依赖结构,并构建一个已移除所有不必要项的替代目标。

svi = SVI(model, guide, optimizer, TraceGraph_ELBO())

请注意,利用此依赖信息可能会带来很小的计算开销,因此 TraceGraph_ELBO 仅应在模型包含不可重参数化随机变量的情况下使用;在大多数应用中,Trace_ELBO 就足够了。

一个 Rao-Blackwell 化示例:

假设我们有一个包含 \(K\) 个组成部分的高斯混合模型。对于每个数据点,我们:(i) 首先采样组成部分分布 \(k \in [1,...,K]\);(ii) 使用第 \(k^{\rm th}\) 个组成部分分布观察数据点。编写这种模型的最简单方法如下:

ks = pyro.sample("k", dist.Categorical(probs)
                          .to_event(1))
pyro.sample("obs", dist.Normal(locs[ks], scale)
                       .to_event(1),
            obs=data)

由于用户没有注意标记模型中的任何条件独立性,Pyro 的 SVI 类构建的梯度估计器无法利用 Rao-Blackwell 化,导致梯度估计器往往具有高方差。为了解决这个问题,用户需要显式标记条件独立性。令人高兴的是,这并不是很多工作

# mark conditional independence
# (assumed to be along the rightmost tensor dimension)
with pyro.plate("foo", data.size(-1)):
    ks = pyro.sample("k", dist.Categorical(probs))
    pyro.sample("obs", dist.Normal(locs[ks], scale),
                obs=data)

就是这样。

旁注:Pyro 中的依赖跟踪

最后,谈谈依赖跟踪。Pyro 使用来源(provenance)的概念来跟踪包含任意 Python 代码的随机函数中的依赖关系(参见参考文献 [5])。在编程语言理论中,变量的来源是指影响其值的变量或计算的历史。下面的简单示例展示了在 Pyro 中如何通过 PyTorch 操作跟踪来源,其中来源是用户定义的 frozenset 对象

from pyro.ops.provenance import get_provenance, track_provenance

a = track_provenance(torch.randn(3), frozenset({"a"}))
b = track_provenance(torch.randn(3), frozenset({"b"}))
c = torch.randn(3)  # no provenance information

# For a unary operation, the provenance of the output tensor
# equals the provenace of the input tensor
assert get_provenance(a.exp()) == frozenset({"a"})
# In general, the provenance of the output tensors of any op
# is the union of provenances of input tensors.
assert get_provenance(a * (b + c)) == frozenset({"a", "b"})

TraceGraph_ELBO 利用这一概念,通过中间计算跟踪不可重参数化随机变量的细粒度动态依赖信息,因为这些计算共同形成了对数似然。在内部,不可重参数化采样点使用 TrackNonReparam messenger 进行跟踪

def model():
    probs_a = torch.tensor([0.3, 0.7])
    probs_b = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
    probs_c = torch.tensor([[0.5, 0.5], [0.6, 0.4]])
    a = pyro.sample("a", dist.Categorical(probs_a))
    b = pyro.sample("b", dist.Categorical(probs_b[a]))
    pyro.sample("c", dist.Categorical(probs_c[b]), obs=torch.tensor(0))

with TrackNonReparam():
    model_tr = trace(model).get_trace()
model_tr.compute_log_prob()

assert get_provenance(model_tr.nodes["a"]["log_prob"]) == frozenset({'a'})
assert get_provenance(model_tr.nodes["b"]["log_prob"]) == frozenset({'b', 'a'})
assert get_provenance(model_tr.nodes["c"]["log_prob"]) == frozenset({'b', 'a'})

使用依赖数据的基线降低方差

降低 ELBO 梯度估计器方差的第二种策略称为基线(参见例如参考文献 [6])。它实际上利用了与上述方差降低策略相同的数学原理,只不过现在我们不是移除项,而是添加项。基本上,与其移除期望为零但倾向于贡献方差的项,我们将添加经过特殊选择的、期望为零但作用是降低方差的项。因此,这是一种控制变量策略。

更详细地说,其思想是利用这样一个事实:对于任何常数 \(b\),以下恒等式成立

\[\mathbb{E}_{q_{\phi}({\bf z})} \left [\nabla_{\phi} (\log q_{\phi}({\bf z}) \times b) \right]=0\]

这是因为 \(q(\cdot)\) 是归一化的,所以

\[\mathbb{E}_{q_{\phi}({\bf z})} \left [\nabla_{\phi} \log q_{\phi}({\bf z}) \right]= \int \!d{\bf z} \; q_{\phi}({\bf z}) \nabla_{\phi} \log q_{\phi}({\bf z})= \int \! d{\bf z} \; \nabla_{\phi} q_{\phi}({\bf z})= \nabla_{\phi} \int \! d{\bf z} \; q_{\phi}({\bf z})=\nabla_{\phi} 1 = 0\]

这意味着我们可以替换任何项

\[\log q_{\phi}({\bf z}_i) \overline{f_{\phi}({\bf z})}\]

在我们的替代目标中替换为

\[\log q_{\phi}({\bf z}_i) \left(\overline{f_{\phi}({\bf z})}-b\right)\]

这样做不会影响我们梯度估计器的均值,但会影响方差。如果我们明智地选择 \(b\),我们有望降低方差。事实上,\(b\) 不需要是常数:它可以取决于 \({\bf z}_i\) 上游(或同侧)的任何随机选择。

Pyro 中的基线

用户可以通过几种方式指示 Pyro 在随机变分推断中使用基线。由于基线可以附加到任何不可重参数化的随机变量上,当前的基线接口位于 pyro.sample 语句级别。具体来说,基线接口使用参数 baseline,它是一个指定基线选项的字典。请注意,仅在引导内的采样语句(而非模型中)指定基线才有意义。

衰减平均基线

最简单的基线是通过 \(\overline{f_{\phi}({\bf z})}\) 的近期样本的移动平均构造的。在 Pyro 中,可以通过如下方式调用这种基线

z = pyro.sample("z", dist.Bernoulli(...),
                infer=dict(baseline={'use_decaying_avg_baseline': True,
                                     'baseline_beta': 0.95}))

可选参数 baseline_beta 指定衰减平均的衰减率(默认值:0.90)。

神经基线

在某些情况下,衰减平均基线效果很好。在其他情况下,使用依赖于上游随机性的基线对于获得良好的方差降低至关重要。构建此类基线的一个强大方法是使用一个可以在学习过程中进行调整的神经网络。Pyro 提供了两种指定此类基线的方法(有关扩展示例,请参阅 AIR 教程)。

首先,用户需要确定基线将接收哪些输入(例如当前考虑的数据点或先前采样的随机变量)。然后,用户需要构建一个封装基线计算的 nn.Module。这可能看起来像这样

class BaselineNN(nn.Module):
    def __init__(self, dim_input, dim_hidden):
        super().__init__()
        self.linear = nn.Linear(dim_input, dim_hidden)
        # ... finish initialization ...

    def forward(self, x):
        hidden = self.linear(x)
        # ... do more computations ...
        return baseline

然后,假设 BaselineNN 对象 baseline_module 已在其他地方初始化,在引导中我们将有类似以下的代码

def guide(x):  # here x is the current mini-batch of data
    pyro.module("my_baseline", baseline_module)
    # ... other computations ...
    z = pyro.sample("z", dist.Bernoulli(...),
                    infer=dict(baseline={'nn_baseline': baseline_module,
                                         'nn_baseline_input': x}))

这里的参数 nn_baseline 告诉 Pyro 使用哪个 nn.Module 来构建基线。在后端,参数 nn_baseline_input 被馈送到模块的前向方法中,以计算基线 \(b\)。请注意,基线模块需要通过 pyro.module 调用向 Pyro 注册,以便 Pyro 了解模块内的可训练参数。

在底层,Pyro 构建了一个以下形式的损失函数

\[{\rm baseline\; loss} \equiv\left(\overline{f_{\phi}({\bf z})} - b \right)^2\]

该损失函数用于调整神经网络的参数。没有定理表明这是在这种情况下使用的最佳损失函数(它不是),但在实践中它可能效果很好。就像衰减平均基线一样,其思想是能够跟踪均值 \(\overline{f_{\phi}({\bf z})}\) 的基线将有助于降低方差。在底层,SVI 在 ELBO 步骤的同时,对基线损失也执行一个步骤。

请注意,在实践中,对基线参数使用不同的学习超参数集(例如,更高的学习率)可能很重要。在 Pyro 中,这可以通过以下方式完成

def per_param_args(param_name):
    if 'baseline' in param_name:
        return {"lr": 0.010}
    else:
        return {"lr": 0.001}

optimizer = optim.Adam(per_param_args)

请注意,为了使整个过程正确,基线参数应仅通过基线损失进行优化。同样,模型和引导参数应仅通过 ELBO 进行优化。为了确保在底层实现这一点,SVI 会将进入 ELBO 的基线 \(b\) 从 autograd 图中分离。此外,由于神经基线的输入可能依赖于模型和引导的参数,因此在将输入馈送到神经网络之前,输入也会从 autograd 图中分离。

最后,用户还有另一种指定神经基线的方法。只需使用参数 baseline_value

b = # do baseline computation
z = pyro.sample("z", dist.Bernoulli(...),
                infer=dict(baseline={'baseline_value': b}))

这与上述方式相同,只是在这种情况下,用户有责任确保连接 \(b\) 与模型和引导参数的任何 autograd 记录都被切断。或者换句话说,使用 PyTorch 用户更熟悉的语言来说,任何依赖于 \(\theta\)\(\phi\)\(b\) 输入都需要使用 detach() 语句从 autograd 图中分离。

一个包含基线的完整示例

回想一下,在第一个 SVI 教程中,我们考虑了一个用于模拟抛硬币的 Bernoulli-Beta 模型。由于 Beta 随机变量不可重参数化(或者说不容易重参数化),相应的 ELBO 梯度可能会相当嘈杂。在这种情况下,我们通过使用提供(近似)重参数化梯度的 Beta 分布来解决这个问题。在这里,我们展示了一个简单的衰减平均基线如何在 Beta 分布被视为不可重参数化(从而 ELBO 梯度估计器是评分函数类型)的情况下降低方差。同时,我们还使用 plate 以完全向量化的方式编写模型。

我们不直接比较梯度方差,而是看看 SVI 需要多少步才能收敛。回想一下,对于这个特定模型(由于共轭性),我们可以计算精确的后验。因此,为了评估基线在这种情况下的效用,我们设置了以下简单实验。我们在指定的一组变分参数下初始化引导。然后我们进行 SVI,直到变分参数达到精确后验参数的固定容差范围内。我们在使用和不使用衰减平均基线的情况下都这样做。然后我们比较两种情况下所需的梯度步数。以下是完整的代码

(由于除了使用 plate use_decaying_avg_baseline 之外,这部分代码与 SVI 教程的第一部分和第二部分的代码非常相似,因此我们不再逐行讲解代码。)

[ ]:
import os
import torch
import torch.distributions.constraints as constraints
import pyro
import pyro.distributions as dist
# Pyro also has a reparameterized Beta distribution so we import
# the non-reparameterized version to make our point
from pyro.distributions.testing.fakes import NonreparameterizedBeta
import pyro.optim as optim
from pyro.infer import SVI, TraceGraph_ELBO
import sys

assert pyro.__version__.startswith('1.9.1')

# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)
max_steps = 2 if smoke_test else 10000


def param_abs_error(name, target):
    return torch.sum(torch.abs(target - pyro.param(name))).item()


class BernoulliBetaExample:
    def __init__(self, max_steps):
        # the maximum number of inference steps we do
        self.max_steps = max_steps
        # the two hyperparameters for the beta prior
        self.alpha0 = 10.0
        self.beta0 = 10.0
        # the dataset consists of six 1s and four 0s
        self.data = torch.zeros(10)
        self.data[0:6] = torch.ones(6)
        self.n_data = self.data.size(0)
        # compute the alpha parameter of the exact beta posterior
        self.alpha_n = self.data.sum() + self.alpha0
        # compute the beta parameter of the exact beta posterior
        self.beta_n = - self.data.sum() + torch.tensor(self.beta0 + self.n_data)
        # initial values of the two variational parameters
        self.alpha_q_0 = 15.0
        self.beta_q_0 = 15.0

    def model(self, use_decaying_avg_baseline):
        # sample `latent_fairness` from the beta prior
        f = pyro.sample("latent_fairness", dist.Beta(self.alpha0, self.beta0))
        # use plate to indicate that the observations are
        # conditionally independent given f and get vectorization
        with pyro.plate("data_plate"):
            # observe all ten datapoints using the bernoulli likelihood
            pyro.sample("obs", dist.Bernoulli(f), obs=self.data)

    def guide(self, use_decaying_avg_baseline):
        # register the two variational parameters with pyro
        alpha_q = pyro.param("alpha_q", torch.tensor(self.alpha_q_0),
                             constraint=constraints.positive)
        beta_q = pyro.param("beta_q", torch.tensor(self.beta_q_0),
                            constraint=constraints.positive)
        # sample f from the beta variational distribution
        baseline_dict = {'use_decaying_avg_baseline': use_decaying_avg_baseline,
                         'baseline_beta': 0.90}
        # note that the baseline_dict specifies whether we're using
        # decaying average baselines or not
        pyro.sample("latent_fairness", NonreparameterizedBeta(alpha_q, beta_q),
                    infer=dict(baseline=baseline_dict))

    def do_inference(self, use_decaying_avg_baseline, tolerance=0.80):
        # clear the param store in case we're in a REPL
        pyro.clear_param_store()
        # setup the optimizer and the inference algorithm
        optimizer = optim.Adam({"lr": .0005, "betas": (0.93, 0.999)})
        svi = SVI(self.model, self.guide, optimizer, loss=TraceGraph_ELBO())
        print("Doing inference with use_decaying_avg_baseline=%s" % use_decaying_avg_baseline)

        # do up to this many steps of inference
        for k in range(self.max_steps):
            svi.step(use_decaying_avg_baseline)
            if k % 100 == 0:
                print('.', end='')
                sys.stdout.flush()

            # compute the distance to the parameters of the true posterior
            alpha_error = param_abs_error("alpha_q", self.alpha_n)
            beta_error = param_abs_error("beta_q", self.beta_n)

            # stop inference early if we're close to the true posterior
            if alpha_error < tolerance and beta_error < tolerance:
                break

        print("\nDid %d steps of inference." % k)
        print(("Final absolute errors for the two variational parameters " +
               "were %.4f & %.4f") % (alpha_error, beta_error))

# do the experiment
bbe = BernoulliBetaExample(max_steps=max_steps)
bbe.do_inference(use_decaying_avg_baseline=True)
bbe.do_inference(use_decaying_avg_baseline=False)

样本输出

Doing inference with use_decaying_avg_baseline=True
....................
Did 1932 steps of inference.
Final absolute errors for the two variational parameters were 0.7997 & 0.0800
Doing inference with use_decaying_avg_baseline=False
..................................................
Did 4908 steps of inference.
Final absolute errors for the two variational parameters were 0.7991 & 0.2532

在这次特定的运行中,我们可以看到基线大致将 SVI 所需的步骤数减半。结果是随机的,并且会因运行而异,但这仍然是一个令人鼓舞的结果。这是一个相当刻意的示例,但对于某些模型和引导对,基线可以带来显著的优势。

参考文献

[1] 概率编程中的自动化变分推断,      David Wingate, Theo Weber

[2] 黑盒变分推断,     Rajesh Ranganath, Sean Gerrish, David M. Blei

[3] 自编码变分贝叶斯,     Diederik P Kingma, Max Welling

[4] 使用随机计算图的梯度估计,      John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel

[5] 用于高效推断的概率程序的非标准解释      David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind

[6] 信念网络中的神经变分推断与学习      Andriy Mnih, Karol Gregor