SVI 第二部分:条件独立、子抽样和均摊

目标:将 SVI 扩展到大型数据集

对于一个包含 \(N\) 个观测值的模型,运行 modelguide 并构建 ELBO 需要评估 log pdf,其复杂度随 \(N\) 增长而快速增加。如果我们要扩展到大型数据集,这是一个问题。幸运的是,只要我们的 model/guide 具有一些可以利用的条件独立结构,ELBO 目标函数自然支持子抽样。例如,在观测值给定隐变量是条件独立的情况下,ELBO 中的对数似然项可以通过以下方式近似:

\[ \sum_{i=1}^N \log p({\bf x}_i | {\bf z}) \approx \frac{N}{M} \sum_{i\in{\mathcal{I}_M}} \log p({\bf x}_i | {\bf z})\]

其中 \(\mathcal{I}_M\) 是大小为 \(M\) 的索引迷你批次,且 \(M<N\)(有关讨论,请参阅参考文献 [1,2])。太好了,问题解决了!但是在 Pyro 中如何实现呢?

在 Pyro 中标记条件独立性

如果用户想在 Pyro 中实现此类功能,首先需要确保 model 和 guide 的编写方式能够让 Pyro 利用相关的条件独立性。让我们看看如何实现。Pyro 提供了两个语言原语来标记条件独立性:platemarkov。我们先从两者中较简单的一个开始。

顺序 plate

让我们回到上一个教程中使用的示例。为了方便起见,我们在此复制 model 的主要逻辑:

def model(data):
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data using pyro.sample with the obs keyword argument
    for i in range(len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

对于此模型,观测值在给定隐随机变量 latent_fairness 的条件下是条件独立的。要在 Pyro 中明确标记这一点,我们只需要将 Python 内置的 range 替换为 Pyro 的构造 plate

def model(data):
    # sample f from the beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data [WE ONLY CHANGE THE NEXT LINE]
    for i in pyro.plate("data_loop", len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

我们看到 pyro.platerange 非常相似,但有一个主要区别:每次调用 plate 都要求用户提供一个唯一的名称。第二个参数是一个整数,就像 range 一样。

到目前为止一切顺利。Pyro 现在可以利用给定隐随机变量的观测值的条件独立性了。但这实际上是如何工作的呢?基本上,pyro.plate 是使用上下文管理器实现的。在 for 循环体每次执行时,我们进入一个新的(条件)独立上下文,并在 for 循环体结束时退出该上下文。让我们非常明确地说明这一点:

  • 因为每个观测到的 pyro.sample 语句都发生在 for 循环体的不同执行中,Pyro 将每个观测标记为独立。

  • 这种独立性严格来说是一种条件独立性,*给定* latent_fairness,因为 latent_fairness 是*外部*于 data_loop 上下文采样的。

在继续之前,让我们提一下使用顺序 plate 时应避免的一些陷阱。考虑以下代码片段的变体:

# WARNING do not do this!
my_reified_list = list(pyro.plate("data_loop", len(data)))
for i in my_reified_list:
    pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

这*不会*达到期望的行为,因为 list() 在调用任何一个 pyro.sample 语句之前,会完全进入并退出 data_loop 上下文。类似地,用户需要注意不要让可变计算跨越上下文管理器的边界泄露,因为这可能导致微妙的错误。例如,pyro.plate 不适用于循环的每次迭代都依赖于上一次迭代的时序模型;在这种情况下,应该使用 rangepyro.markov 代替。

向量化 plate

从概念上讲,矢量化 plate 与顺序 plate 相同,区别在于它是一个矢量化操作(就像 torch.arange 相对于 range 一样)。因此,与顺序 plate 中出现的显式 for 循环相比,它可能实现巨大的速度提升。让我们看看这在我们的运行示例中是什么样子的。首先,我们需要 data 是张量的形式

data = torch.zeros(10)
data[0:6] = torch.ones(6)  # 6 heads and 4 tails

然后我们有

with pyro.plate('observe_data'):
    pyro.sample('obs', dist.Bernoulli(f), obs=data)

让我们将它与类比的顺序 plate 用法逐点比较

  • 这两种模式都要求用户指定一个唯一的名称。

  • 注意,这段代码只引入了一个(观测到的)随机变量(即 obs),因为整个张量被一次性考虑。

  • 由于在这种情况下不需要迭代器,因此不需要指定参与 plate 上下文的张量的长度

请注意,对于顺序 plate 提到的注意事项同样适用于矢量化 plate

子抽样

现在我们知道如何在Pyro中标记条件独立性。这本身就很有用(参见SVI第三部分的依赖跟踪章节),但我们也希望进行子采样,以便在大数据集上进行SVI。根据模型和指南的结构,Pyro支持几种进行子采样的方法。让我们逐一介绍这些方法。

使用 plate 进行自动子抽样

让我们先看最简单的情况,在这种情况下,通过给 plate 添加一两个额外参数,我们可以轻松获得子采样功能

for i in pyro.plate("data_loop", len(data), subsample_size=5):
    pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

就是这样了:我们只需要使用参数 subsample_size。每当我们运行 model() 时,现在只评估 data 中随机选择的5个数据点的对数似然;此外,对数似然将自动按适当的因子 \(\tfrac{10}{5} = 2\) 进行缩放。矢量化 plate 呢?其用法完全类似

with pyro.plate('observe_data', size=10, subsample_size=5) as ind:
    pyro.sample('obs', dist.Bernoulli(f),
                obs=data.index_select(0, ind))

重要的是,plate 现在返回一个索引张量 ind,在这种情况下,其长度将为5。注意,除了参数 subsample_size 之外,我们还传递了参数 size,这样 plate 才能知道张量 data 的完整大小,从而计算出正确的缩放因子。就像顺序 plate 一样,用户负责使用 plate 提供的索引来选择正确的数据点。

最后,注意如果 data 位于GPU上,用户必须将 device 参数传递给 plate

使用 plate 进行自定义子抽样策略

每次运行上述 model() 时,plate 都会采样新的子采样索引。由于这种子采样是无状态的,这可能导致一些问题:基本上,对于足够大的数据集,即使经过大量迭代,仍然存在不可忽略的概率,某些数据点从未被选中。为了避免这种情况,用户可以通过使用 platesubsample 参数来控制子采样。详情请参阅文档

仅存在局部随机变量时的子抽样

我们设想一个模型,其联合概率密度由下式给出:

\[p({\bf x}, {\bf z}) = \prod_{i=1}^N p({\bf x}_i | {\bf z}_i) p({\bf z}_i)\]

对于具有这种依赖结构的模型,通过子采样本引入的缩放因子会以相同的量缩放 ELBO 中的所有项。例如,对于一个简单的 VAE 模型就是这种情况。这解释了为什么对于 VAE,用户可以完全控制子采样本,并将 mini-batches 直接传递给模型和 guide;plate 仍然使用,但 subsample_sizesubsample 则不使用。要详细了解其外观,请参阅 VAE 教程

同时存在全局和局部随机变量时的子抽样

在上面的抛硬币示例中,plate 出现在模型中,但没有出现在 guide 中,因为唯一被子采样的只有观测值。让我们看一个更复杂的示例,其中子采样本出现在模型和 guide 中。为了简化起见,我们让讨论保持抽象,避免编写完整的模型和 guide。

考虑由以下联合分布指定的模型

\[ p({\bf x}, {\bf z}, \beta) = p(\beta) \prod_{i=1}^N p({\bf x}_i | {\bf z}_i) p({\bf z}_i | \beta)\]

\(N\) 个观测值 \(\{ {\bf x}_i \}\)\(N\) 个局部隐随机变量 \(\{ {\bf z}_i \}\)。还有一个全局隐随机变量 \(\beta\)。我们的 guide 将分解为

\[q({\bf z}, \beta) = q(\beta) \prod_{i=1}^N q({\bf z}_i | \beta, \lambda_i)\]

这里我们明确引入了 \(N\) 个局部变分参数 \(\{\lambda_i \}\),而其他变分参数则被省略。模型和 guide 都具有条件独立性。特别是在模型方面,给定 \(\{ {\bf z}_i \}\),观测值 \(\{ {\bf x}_i \}\) 是独立的。此外,给定 \(\beta\),隐随机变量 \(\{\bf {z}_i \}\) 是独立的。在 guide 方面,给定变分参数 \(\{\lambda_i \}\)\(\beta\),隐随机变量 \(\{\bf {z}_i \}\) 是独立的。为了在 Pyro 中标记这些条件独立性并进行子采样本,我们需要在模型guide 中都使用 plate。让我们使用顺序 plate 勾勒出基本逻辑(更完整的代码会包含 pyro.param 语句等)。首先是模型

def model(data):
    beta = pyro.sample("beta", ...) # sample the global RV
    for i in pyro.plate("locals", len(data)):
        z_i = pyro.sample("z_{}".format(i), ...)
        # compute the parameter used to define the observation
        # likelihood using the local random variable
        theta_i = compute_something(z_i)
        pyro.sample("obs_{}".format(i), dist.MyDist(theta_i), obs=data[i])

请注意,与我们正在进行的抛硬币示例相反,这里我们在 plate 循环内部和外部都有 pyro.sample 语句。接下来是 guide

def guide(data):
    beta = pyro.sample("beta", ...) # sample the global RV
    for i in pyro.plate("locals", len(data), subsample_size=5):
        # sample the local RVs
        pyro.sample("z_{}".format(i), ..., lambda_i)

请注意,关键在于索引在 guide 中只会进行一次子采样本;Pyro 后端确保在模型执行期间使用相同的索引集。因此,subsample_size 只需要在 guide 中指定。

均摊

让我们再次考虑一个具有全局和局部隐随机变量以及局部变分参数的模型

\[ p({\bf x}, {\bf z}, \beta) = p(\beta) \prod_{i=1}^N p({\bf x}_i | {\bf z}_i) p({\bf z}_i | \beta) \qquad \qquad q({\bf z}, \beta) = q(\beta) \prod_{i=1}^N q({\bf z}_i | \beta, \lambda_i)\]

对于中小规模的 \(N\),使用这样的局部变分参数可能是一个很好的方法。然而,如果 \(N\) 很大,我们进行优化的空间随着 \(N\) 增长的事实可能会成为一个真正的问题。避免这种随着数据集大小出现的令人讨厌的增长的一种方法是摊销

其工作原理如下。我们不引入局部变分参数,而是学习一个单一的参数化函数 \(f(\cdot)\),并使用具有以下形式的变分分布

\[q(\beta) \prod_{n=1}^N q({\bf z}_i | f({\bf x}_i))\]

函数 \(f(\cdot)\)(它基本上将给定的观测值映射到一组针对该数据点定制的变分参数)需要足够丰富以准确捕捉后验分布,但现在我们可以处理大型数据集,而无需引入数量巨大的变分参数。这种方法还有其他好处:例如,在学习期间,\(f(\cdot)\)有效地允许我们在不同数据点之间共享统计能力。请注意,这正是 VAE 中使用的方法。

张量形状和向量化 plate

本教程中 pyro.plate 的用法仅限于相对简单的情况。例如,没有一个 plate 嵌套在其他 plate 中。为了充分利用 plate,用户必须小心使用 Pyro 的张量形状语义。有关讨论,请参阅 张量形状教程

参考文献

[1] Stochastic Variational Inference,      Matthew D. Hoffman, David M. Blei, Chong Wang, John Paisley

[2] Auto-Encoding Variational Bayes,     Diederik P Kingma, Max Welling