SVI 第一部分:Pyro 中的随机变分推断介绍

Pyro 的设计特别注重支持随机变分推断 (Stochastic Variational Inference, SVI) 作为一种通用推断算法。让我们看看如何在 Pyro 中进行变分推断。

设置

我们假设已经在 Pyro 中定义了模型(有关如何操作的更多详细信息,请参见Pyro 介绍)。快速回顾一下,模型被定义为一个随机函数 model(*args, **kwargs),在一般情况下它会接收参数。模型 model() 的不同部分通过以下映射编码:

  1. 观测值 \(\Longleftrightarrow\) 带有 obs 参数的 pyro.sample

  2. 隐随机变量 \(\Longleftrightarrow\) pyro.sample

  3. 参数 \(\Longleftrightarrow\) pyro.param

现在让我们建立一些符号。模型有观测值 \({\bf x}\) 和隐随机变量 \({\bf z}\) 以及参数 \(\theta\)。它的联合概率密度形式为

\[p_{\theta}({\bf x}, {\bf z}) = p_{\theta}({\bf x}|{\bf z}) p_{\theta}({\bf z})\]

我们假设构成 \(p_{\theta}({\bf x}, {\bf z})\) 的各种概率分布 \(p_i\) 具有以下性质:

  1. 我们可以从每个 \(p_i\) 中采样

  2. 我们可以计算逐点的对数概率密度函数 \(\log p_i\)

  3. \(p_i\) 关于参数 \(\theta\) 可微

模型学习

在此背景下,我们学习好模型的标准是最大化对数证据 (log evidence),即我们要找到由下式给出的 \(\theta\)

\[\theta_{\rm{max}} = \underset{\theta}{\operatorname{argmax}} \log p_{\theta}({\bf x})\]

其中对数证据 \(\log p_{\theta}({\bf x})\) 由下式给出

\[\log p_{\theta}({\bf x}) = \log \int\! d{\bf z}\; p_{\theta}({\bf x}, {\bf z})\]

在一般情况下,这是一个双重困难的问题。这是因为(即使对于固定的 \(\theta\)),隐随机变量 \(\bf z\) 上的积分通常是难解的。此外,即使我们知道如何计算所有 \(\theta\) 值的对数证据,将对数证据作为 \(\theta\) 的函数最大化通常也是一个困难的非凸优化问题。

除了找到 \(\theta_{\rm{max}}\),我们还想计算隐变量 \(\bf z\) 的后验分布

\[ p_{\theta_{\rm{max}}}({\bf z} | {\bf x}) = \frac{p_{\theta_{\rm{max}}}({\bf x} , {\bf z})}{ \int \! d{\bf z}\; p_{\theta_{\rm{max}}}({\bf x} , {\bf z}) }\]

请注意,此表达式的分母是(通常难解的)证据。变分推断提供了一种寻找 \(\theta_{\rm{max}}\) 并计算后验 \(p_{\theta_{\rm{max}}}({\bf z} | {\bf x})\) 的近似值的方法。让我们看看它是如何工作的。

指南 (Guide)

基本思想是我们引入一个参数化分布 \(q_{\phi}({\bf z})\),其中 \(\phi\) 被称为变分参数 (variational parameters)。这种分布在很多文献中被称为变分分布 (variational distribution),在 Pyro 中它被称为 指南 (guide)(一个音节代替九个!)。指南将作为后验分布的近似。我们可以将 \(\phi\) 视为参数化了一个概率分布的空间或族。我们的目标是找到该空间中(不一定是唯一的)最接近后验分布的概率分布。

就像模型一样,指南也被编码为一个随机函数 guide(),其中包含 pyro.samplepyro.param 语句。它包含观测数据,因为指南需要是一个正确归一化的分布。请注意,Pyro 强制要求 model()guide() 具有相同的调用签名,即两个可调用对象应接受相同的参数。

由于指南是对后验 \(p_{\theta_{\rm{max}}}({\bf z} | {\bf x})\) 的近似,指南需要提供模型中所有隐随机变量的有效联合概率密度。回想一下,当在 Pyro 中使用原始语句 pyro.sample() 指定随机变量时,第一个参数表示随机变量的名称。这些名称将用于对齐模型和指南中的随机变量。为了更明确,如果模型包含一个随机变量 z_1

def model():
    pyro.sample("z_1", ...)

那么指南也需要有一个匹配的 sample 语句

def guide():
    pyro.sample("z_1", ...)

这两种情况下使用的分布可以不同,但名称必须一一对应。

一旦我们指定了指南(下面给出了一些具体示例),就可以进行推断了。学习将被设置为一个优化问题,其中每次训练迭代都会在 \(\theta-\phi\) 空间中迈出一步,使指南更接近精确的后验。为此,我们需要定义一个合适的目标函数。

ELBO

通过简单的推导(例如参见参考文献[1]),我们可以得到我们想要的结果:证据下界 (Evidence Lower Bound, ELBO)。ELBO 是 \(\theta\)\(\phi\) 的函数,它被定义为关于从指南中抽取的样本的期望

\[{\rm ELBO} \equiv \mathbb{E}_{q_{\phi}({\bf z})} \left [ \log p_{\theta}({\bf x}, {\bf z}) - \log q_{\phi}({\bf z}) \right]\]

根据假设,我们可以计算期望中的对数概率。由于指南被假定为可采样的参数分布,我们可以计算此数量的蒙特卡洛估计。至关重要的是,ELBO 是对数证据的下界,即对于 \(\theta\)\(\phi\) 的所有选择,我们都有

\[\log p_{\theta}({\bf x}) \ge {\rm ELBO}\]

因此,如果我们采取(随机)梯度步骤来最大化 ELBO,我们也将在(期望上)提高对数证据。此外,可以证明 ELBO 和对数证据之间的差距由指南与后验之间的 Kullback-Leibler 散度(KL 散度)给出

\[ \log p_{\theta}({\bf x}) - {\rm ELBO} = \rm{KL}\!\left( q_{\phi}({\bf z}) \lVert p_{\theta}({\bf z} | {\bf x}) \right)\]

这个 KL 散度是衡量两个分布“接近程度”的一种特定(非负)度量。因此,对于固定的 \(\theta\),当我们在 \(\phi\) 空间中采取增加 ELBO 的步骤时,我们减小了指南与后验之间的 KL 散度,即我们将指南移向后验。在一般情况下,我们同时在 \(\theta\)\(\phi\) 空间中采取梯度步骤,使得指南和模型相互追逐,指南跟踪移动的后验 \(\log p_{\theta}({\bf z} | {\bf x})\)。或许有些令人惊讶的是,尽管目标在移动,但对于许多不同的问题,这个优化问题是可以解决的(达到适当的近似水平)。

因此,从高层次上看,变分推断很容易:我们只需要定义一个指南并计算 ELBO 的梯度。实际上,计算一般模型和指南对的梯度会导致一些复杂性(请参见教程SVI 第三部分中的讨论)。就本教程而言,我们假设这是一个已解决的问题,并看看 Pyro 为进行变分推断提供的支持。

SVI

在 Pyro 中,用于进行变分推断的机制封装在 `SVI <https://docs.pyro.org.cn/en/stable/inference_algos.html?highlight=svi>`__ 类中。

用户需要提供三样东西:模型、指南和优化器。我们上面已经讨论了模型和指南,下面将详细讨论优化器,所以我们假设手头已经有了这三样东西。要构建一个通过 ELBO 目标执行优化的 SVI 实例,用户可以这样写:

import pyro
from pyro.infer import SVI, Trace_ELBO
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

SVI 对象提供了两个方法,step()evaluate_loss(),它们封装了变分学习和评估的逻辑。

  1. 方法 step() 采取单个梯度步,并返回损失(即负 ELBO)的估计值。如果提供参数,step() 的参数会传递给 model()guide()

  2. 方法 evaluate_loss() 返回损失的估计值,但采取梯度步。与 step() 一样,如果提供参数,evaluate_loss() 的参数会传递给 model()guide()

当损失是 ELBO 时,这两种方法都接受一个可选参数 num_particles,表示用于计算损失(对于 evaluate_loss)以及损失和梯度(对于 step)的样本数量。

优化器

在 Pyro 中,模型和指南允许是任意随机函数,只要满足以下条件:

  1. guide 不包含带有 obs 参数的 pyro.sample 语句

  2. modelguide 具有相同的调用签名

这带来了一些挑战,因为这意味着 model()guide() 的不同执行可能会有相当不同的行为,例如某些隐随机变量和参数可能只在部分时间出现。实际上,参数可能会在推断过程中动态创建。换句话说,我们进行优化的空间(由 \(\theta\)\(\phi\) 参数化)可以动态增长和变化。

为了支持这种行为,Pyro 需要在学习过程中第一次看到每个参数时动态为其生成一个优化器。幸运的是,PyTorch 有一个轻量级的优化库(参见torch.optim),可以轻松地为动态情况重新利用。

所有这些都由 `optim.PyroOptim <https://docs.pyro.org.cn/en/stable/optimization.html?highlight=optim#pyro.optim.optim.PyroOptim>`__ 类控制,该类基本上是 PyTorch 优化器的一个薄包装。 PyroOptim 接受两个参数:PyTorch 优化器的构造函数 optim_constructor 和优化器参数的规范 optim_args。在高层次上,在优化过程中,每当看到一个新的参数时,就会使用 optim_constructor 根据 optim_args 给定的参数实例化一个给定类型的新优化器。

大多数用户可能不会直接与 PyroOptim 交互,而是与 optim/__init__.py 中定义的别名交互。让我们看看如何使用。有两种方法可以指定优化器参数。在更简单的情况下,optim_args 是一个固定的字典,指定用于实例化所有参数的 PyTorch 优化器的参数。

from pyro.optim import Adam

adam_params = {"lr": 0.005, "betas": (0.95, 0.999)}
optimizer = Adam(adam_params)

第二种指定参数的方式可以实现更精细的控制。在这里,用户必须指定一个可调用对象,当 Pyro 为新看到的参数创建优化器时,该对象将被调用。该可调用对象必须接受参数的 Pyro 名称作为参数。这使得用户能够,例如,为不同的参数自定义学习率。有关此级别控制有用的示例,请参见基线讨论。这是一个说明 API 的简单示例:

from pyro.optim import Adam

def per_param_callable(param_name):
    if param_name == 'my_special_parameter':
        return {"lr": 0.010}
    else:
        return {"lr": 0.001}

optimizer = Adam(per_param_callable)

这只是告诉 Pyro 对于 Pyro 参数 my_special_parameter 使用学习率 0.010,对于所有其他参数使用学习率 0.001

一个简单的例子

最后我们来看一个简单的例子。你拿到了一枚双面硬币。你想确定这枚硬币是否公平,即它出现正面或反面的频率是否相同。你对硬币的公平性有一个基于两个观察的先验信念:

  • 它是一枚由美国铸币局发行的标准 25 美分硬币

  • 由于多年的使用,它有些磨损

因此,虽然你预计硬币刚生产时相当公平,但你允许其公平性此后偏离完美的 1:1 比例。所以如果硬币出现正面与反面的比例为 11:10,你不会感到惊讶。相比之下,如果硬币出现正面与反面的比例为 5:1,你会非常惊讶——它没那么磨损。

为了将其转化为概率模型,我们将正面和反面编码为 10。我们将硬币的公平性编码为一个实数 \(f\),其中 \(f\) 满足 \(f \in [0.0, 1.0]\)\(f=0.50\) 对应于完全公平的硬币。我们对 \(f\) 的先验信念将通过 Beta 分布编码,具体来说是 \(\rm{Beta}(10,10)\),这是一个在区间 \([0.0, 1.0]\) 上的对称概率分布,峰值在 \(f=0.5\)

图 1:编码我们对硬币公平性先验信念的 Beta 分布。

为了学习比我们模糊的先验更精确的硬币公平性信息,我们需要做一次实验并收集一些数据。假设我们抛硬币 10 次并记录每次抛掷的结果。实际上,我们可能想做不止 10 次试验,但这只是一个教程。

假设我们将数据收集在一个列表 data 中,相应的模型由下式给出:

import pyro.distributions as dist

def model(data):
    # define the hyperparameters that control the Beta prior
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    # sample f from the Beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data
    for i in range(len(data)):
        # observe datapoint i using the Bernoulli
        # likelihood Bernoulli(f)
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

这里我们有一个隐随机变量('latent_fairness'),它服从 \(\rm{Beta}(10, 10)\) 分布。以此随机变量为条件,我们使用伯努利似然观测每个数据点。注意,每个观测值在 Pyro 中都被分配了一个唯一的名称。

我们的下一个任务是定义相应的指南,即隐随机变量 \(f\) 的合适变分分布。这里唯一的真正要求是 \(q(f)\) 应该是一个在区间 \([0.0, 1.0]\) 上的概率分布,因为 \(f\) 在该区间之外没有意义。一个简单的选择是使用另一个由两个可训练参数 \(\alpha_q\)\(\beta_q\) 参数化的 Beta 分布。实际上,在这种特定情况下,这是“正确”的选择,因为伯努利分布和 Beta 分布的共轭性意味着精确后验是 Beta 分布。在 Pyro 中我们写

def guide(data):
    # register the two variational parameters with Pyro.
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0),
                         constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0),
                        constraint=constraints.positive)
    # sample latent_fairness from the distribution Beta(alpha_q, beta_q)
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

这里有几点需要注意:

  • 我们注意确保模型和指南中随机变量的名称完全对齐。

  • model(data)guide(data) 接受相同的参数。

  • 变分参数是 torch.tensorrequires_grad 标志由 pyro.param 自动设置为 True

  • 我们使用 constraint=constraints.positive 来确保 alpha_qbeta_q 在优化过程中保持非负。在内部,指数变换确保了正性。

现在我们可以进行随机变分推断了。

# set up the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 5000
# do gradient steps
for step in range(n_steps):
    svi.step(data)

注意,在 step() 方法中,我们传入了数据,这些数据随后会传递给模型和指南。

此时我们唯一缺少的是一些数据。因此,让我们创建一些数据并将上面的所有代码片段组合成一个完整的脚本:

[ ]:
import math
import os
import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)
n_steps = 2 if smoke_test else 2000

assert pyro.__version__.startswith('1.9.1')

# clear the param store in case we're in a REPL
pyro.clear_param_store()

# create some data with 6 observed heads and 4 observed tails
data = []
for _ in range(6):
    data.append(torch.tensor(1.0))
for _ in range(4):
    data.append(torch.tensor(0.0))

def model(data):
    # define the hyperparameters that control the Beta prior
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    # sample f from the Beta prior
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    # loop over the observed data
    for i in range(len(data)):
        # observe datapoint i using the ernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

def guide(data):
    # register the two variational parameters with Pyro
    # - both parameters will have initial value 15.0.
    # - because we invoke constraints.positive, the optimizer
    # will take gradients on the unconstrained parameters
    # (which are related to the constrained parameters by a log)
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0),
                         constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0),
                        constraint=constraints.positive)
    # sample latent_fairness from the distribution Beta(alpha_q, beta_q)
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

# setup the optimizer
adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# do gradient steps
for step in range(n_steps):
    svi.step(data)
    if step % 100 == 0:
        print('.', end='')

# grab the learned variational parameters
alpha_q = pyro.param("alpha_q").item()
beta_q = pyro.param("beta_q").item()

# here we use some facts about the Beta distribution
# compute the inferred mean of the coin's fairness
inferred_mean = alpha_q / (alpha_q + beta_q)
# compute inferred standard deviation
factor = beta_q / (alpha_q * (1.0 + alpha_q + beta_q))
inferred_std = inferred_mean * math.sqrt(factor)

print("\nBased on the data and our prior belief, the fairness " +
      "of the coin is %.3f +- %.3f" % (inferred_mean, inferred_std))

示例输出:

based on the data and our prior belief, the fairness of the coin is 0.532 +- 0.090

这个估计值应与精确后验均值进行比较,在本例中精确后验均值是 \(16/30 = 0.5\bar{3}\)。注意,对硬币公平性的最终估计值介于先验偏好的公平性(即 \(0.50\))和原始经验频率暗示的公平性(\(6/10 = 0.60\))之间。

参考文献

[1] Probabilistic Programming 中的自动化变分推断,      David Wingate, Theo Weber

[2] 黑箱变分推断,     Rajesh Ranganath, Sean Gerrish, David M. Blei

[3] 自编码变分贝叶斯,     Diederik P Kingma, Max Welling

[4] 变分推断:面向统计学家的综述,     David M. Blei, Alp Kucukelbir, Jon D. McAuliffe