Poutine:Pyro 中使用 Effect Handler 编程指南

致读者:本教程是关于 Pyro 的 effect handling 库 Poutine 的 API 细节指南。我们建议读者首先熟悉简化的 minipyro.py,其中包含 Pyro 运行时和此处描述的 effect handler 抽象的最小化、可读实现。Pyro 的 effect handler 库比 minipyro 更通用,但也包含更多间接层;建议将它们并行阅读。

[1]:
import torch

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

from pyro.poutine.runtime import effectful

pyro.set_rng_seed(101)

介绍

概率编程中的推断涉及操作或转换写成生成模型的概率程序。例如,几乎所有近似推断算法都需要计算生成模型下潜在变量和观测变量值的未归一化联合概率。

考虑以下示例模型

[2]:
def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))

此模型定义了 "weight""measurement" 的联合概率分布

\[ 重量 \, | \, 猜测 \sim 正态(猜测, 1) \]
\[{\sf measurement} \, | \, {\sf guess}, {\sf weight} \sim {\sf Normal}({\sf weight}, 0.75)\]

如果我们可以访问每个 pyro.sample site 的输入和输出,我们可以计算它们的 log-joint

logp = dist.Normal(guess, 1.0).log_prob(weight).sum() + dist.Normal(weight, 0.75).log_prob(measurement).sum()

然而,上面我们编写 scale 的方式似乎没有暴露这些中间分布对象,并且重写它以返回它们将是侵入性的,并且会违反像 Pyro 这样的概率编程语言旨在强制执行的模型和推断算法之间的关注点分离。

为了解决这个冲突并促进推断算法的开发,Pyro 提供了 Poutine,这是一个 effect handler 库,或用于检查和修改 Pyro 程序行为的可组合构建块。Pyro 的大部分内部机制都是在 Poutine 之上实现的。

初探 Poutine:Pyro 的算法构建模块库

Effect handler 是编程语言社区中常见的抽象,它对编程语言中特定语句(如 pyro.samplepyro.param)的行为赋予了 非标准解释副作用。有关编程语言研究中 effect handler 的背景阅读,请参阅本教程末尾的可选“参考文献”部分。

与其回顾更多定义,不如看看一个解决上述问题的第一个示例:我们可以组合两个现有的 effect handler,poutine.condition(设置 pyro.sample 语句的输出值)和 poutine.trace(记录 pyro.sample 语句的输入、分布和输出),来简洁地定义一个新的 effect handler,用于计算 log-joint

[3]:
def make_log_joint(model):
    def _log_joint(cond_data, *args, **kwargs):
        conditioned_model = poutine.condition(model, data=cond_data)
        trace = poutine.trace(conditioned_model).get_trace(*args, **kwargs)
        return trace.log_prob_sum()
    return _log_joint

scale_log_joint = make_log_joint(scale)
print(scale_log_joint({"measurement": torch.tensor(9.5), "weight": torch.tensor(8.23)}, torch.tensor(8.5)))
tensor(-3.0203)

那个片段很短,但仍有些晦涩——poutine.conditionpoutine.tracetrace.log_prob_sum 都是黑箱。让我们从 poutine.conditionpoutine.trace 中移除一层样板代码,并明确实现 trace.log_prob_sum 的功能

[4]:
from pyro.poutine.trace_messenger import TraceMessenger
from pyro.poutine.condition_messenger import ConditionMessenger

def make_log_joint_2(model):
    def _log_joint(cond_data, *args, **kwargs):
        with TraceMessenger() as tracer:
            with ConditionMessenger(data=cond_data):
                model(*args, **kwargs)

        trace = tracer.trace
        logp = 0.
        for name, node in trace.nodes.items():
            if node["type"] == "sample":
                if node["is_observed"]:
                    assert node["value"] is cond_data[name]
                logp = logp + node["fn"].log_prob(node["value"]).sum()
        return logp
    return _log_joint

scale_log_joint = make_log_joint_2(scale)
print(scale_log_joint({"measurement": 9.5, "weight": 8.23}, 8.5))
tensor(-3.0203)

这使得事情稍微清晰了一些:我们现在可以看到 poutine.tracepoutine.condition 是上下文管理器的封装器,它们可能通过 pyro.sample 内部的某些东西与模型通信。我们还可以看到 poutine.trace 生成一个数据结构(一个 Trace),其中包含一个字典,其键是 sample site 的名称,值是包含每个 site 的分布 ("fn"`) 和输出 ("value"`) 的字典,并且每个 site 的输出值正是 data 中指定的值。

最后,TraceMessengerConditionMessenger 是 Pyro 的 effect handler,或者说是 Messenger:它们是状态化的上下文管理器对象,被放置在全局栈上,并在每次 effectful 操作(例如调用 pyro.sample)时在栈中向上和向下发送消息(因此得名)。当调用其 __enter__ 方法时,即在“with”语句中使用时,Messenger 会被放置在栈的底部。

我们将在本教程稍后更详细地介绍这个过程。有关仅几行代码的简化实现,请参阅 pyro.contrib.minipyro

使用 Messenger API 实现新的 effect handler

虽然通过组合 pyro.poutine 中现有的 effect handler 来构建新的 effect handler 是最容易的,但将新 effect 实现为 pyro.poutine.messenger.Messenger 子类实际上也相当直接。在深入探讨 API 之前,让我们看另一个例子:一个在我们模型执行期间进行求和的 log-joint 计算版本。然后我们将回顾示例的每个部分实际上在做什么。

[5]:
class LogJointMessenger(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    # __call__ is syntactic sugar for using Messengers as higher-order functions.
    # Messenger already defines __call__, but we re-define it here
    # for exposition and to change the return value:
    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__enter__()
        # in their __enter__ methods
        return super().__enter__()

    # __exit__ takes the same arguments in all Python context managers
    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        # All Messenger subclasses must call the base Messenger.__exit__ method
        # in their __exit__ methods.
        return super().__exit__(exc_type, exc_value, traceback)

    # _pyro_sample will be called once per pyro.sample site.
    # It takes a dictionary msg containing the name, distribution,
    # observation or sample value, and other metadata from the sample site.
    def _pyro_sample(self, msg):
        # Any unobserved random variables will trigger this assertion.
        # In the next section, we'll learn how to also handle sampled values.
        assert msg["name"] in self.data
        msg["value"] = self.data[msg["name"]]
        # Since we've observed a value for this site, we set the "is_observed" flag to True
        # This tells any other Messengers not to overwrite msg["value"] with a sample.
        msg["is_observed"] = True
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()

with LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp.clone())

scale_log_joint = LogJointMessenger(cond_data={"measurement": 9.5, "weight": 8.23})(scale)
print(scale_log_joint(8.5))
tensor(-3.0203)
tensor(-3.0203)

一段方便的样板代码,允许将 LogJointMessenger 用作上下文管理器、装饰器或高阶函数,如下所示。Pyro 中大多数现有的 effect handler,包括我们之前使用的 poutine.tracepoutine.condition,都是在 pyro.poutine.handlers 中以这种方式封装的 Messenger`.`

[6]:
def log_joint(model=None, cond_data=None):
    msngr = LogJointMessenger(cond_data=cond_data)
    return msngr(model) if model is not None else msngr

scale_log_joint = log_joint(scale, cond_data={"measurement": 9.5, "weight": 8.23})
print(scale_log_joint(8.5))
tensor(-3.0203)

Messenger API 详解

我们的 LogJointMessenger 实现有三个重要方法:__enter____exit___pyro_sample`.`

__enter____exit__ 是任何 Python 上下文管理器所需的特殊方法。在实现新的 Messenger 类时,如果重写 __enter____exit__,我们总是需要调用基类 Messenger__enter____exit__ 方法,以便正确应用新的 Messenger`.`

最后一个方法 LogJointMessenger._pyro_sample,会在每个 sample site 调用一次。它读取并修改一个 消息,消息是一个字典,包含 sample site 的名称、分布、采样或观测值以及其他元数据。我们将在下一节更详细地检查消息的内容。

实际上,一个通用的 Messenger 不是使用 _pyro_sample,而是包含两个方法,它们在每次执行副作用的操作时调用一次:1. _process_message 修改消息并将结果发送给栈中紧邻其上的 Messenger` 2. _postprocess_message 修改消息并将结果发送给栈中紧邻其下的 Messenger`。它总是在所有激活的 Messenger 对消息应用其 _process_message 方法之后调用。

尽管自定义 Messenger 可以重写 _process_message_postprocess_message,但方便的做法是避免要求所有 effect handler 都了解所有可能的 effectful 操作类型。因此,默认情况下,Messenger._process_message 会使用 msg["type"] 派发到相应的 Messenger._pyro_<type> 方法,例如 LogJointMessenger 中的 Messenger._pyro_sample。就像异常处理代码忽略未处理的异常类型一样,这使得 Messenger 可以简单地将它们不知道如何处理的操作转发到栈中的下一个 Messenger`.`

class Messenger:
    ...
    def _process_message(self, msg):
        method_name = "_pyro_{}".format(msg["type"])  # e.g. _pyro_sample when msg["type"] == "sample"
        if hasattr(self, method_name):
            getattr(self, method_name)(msg)
    ...

插曲:全局 Messenger

有关本节机制的端到端实现,请参阅 pyro.contrib.minipyro

Messenger 应用于 pyro.sample 语句等操作的顺序取决于它们的 __enter__ 方法调用的顺序。Messenger.__enter__ 会将一个 Messenger 添加到全局 handler 栈的末尾(底部)。

class Messenger:
    ...
    # __enter__ pushes a Messenger onto the stack
    def __enter__(self):
        ...
        _PYRO_STACK.append(self)
        ...

    # __exit__ removes a Messenger from the stack
    def __exit__(self, ...):
        ...
        assert _PYRO_STACK[-1] is self
        _PYRO_STACK.pop()
        ...

pyro.poutine.runtime.apply_stack 然后在每次操作时遍历栈两次,首先从底部到顶部应用每个 _process_message,然后从顶部到底部应用每个 _postprocess_message`.`

def apply_stack(msg):  # simplified
    for handler in reversed(_PYRO_STACK):
        handler._process_message(msg)
    ...
    default_process_message(msg)
    ...
    for handler in _PYRO_STACK:
        handler._postprocess_message(msg)
    ...
    return msg

回到 LogJointMessenger 示例

第二个方法 _postprocess_message 是必需的,因为有些 effect 只能在所有其他 effect handler 有机会更新消息一次之后应用。在 LogJointMessenger 的情况下,其他 effect(如枚举)可能会修改 sample site 的值或分布(msg["value"]msg["fn"]`),因此我们将 log-probability 计算移到一个新方法 _pyro_post_sample 中,该方法在所有激活 handler 的 _pyro_sample 方法应用后,由 _postprocess_message 在每个 sample site 调用(通过类似 _process_message 使用的派发机制)。`

[7]:
class LogJointMessenger2(poutine.messenger.Messenger):

    def __init__(self, cond_data):
        self.data = cond_data

    def __call__(self, fn):
        def _fn(*args, **kwargs):
            with self:
                fn(*args, **kwargs)
                return self.logp.clone()
        return _fn

    def __enter__(self):
        self.logp = torch.tensor(0.)
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        self.logp = torch.tensor(0.)
        return super().__exit__(exc_type, exc_value, traceback)

    def _pyro_sample(self, msg):
        if msg["name"] in self.data:
            msg["value"] = self.data[msg["name"]]
            msg["done"] = True

    def _pyro_post_sample(self, msg):
        assert msg["done"]  # the "done" flag asserts that no more modifications to value and fn will be performed.
        self.logp = self.logp + (msg["scale"] * msg["fn"].log_prob(msg["value"])).sum()


with LogJointMessenger2(cond_data={"measurement": 9.5, "weight": 8.23}) as m:
    scale(8.5)
    print(m.logp)
tensor(-3.0203)

Messenger 发送的消息内部结构

正如前两个示例所述,在栈中向上和向下发送的实际消息是带有特定键集的字典。考虑以下 sample 语句

pyro.sample("x", dist.Bernoulli(0.5), infer={"enumerate": "parallel"}, obs=None)

此 sample 语句在应用任何 effect 之前转换为初始消息,每个 effect handler 的 _process_message_postprocess_message 可以就地更新字段或添加新字段。为了完整起见,我们在此列出完整的初始消息

msg = {
    # The following fields contain the name, inputs, function, and output of a site.
    # These are generally the only fields you'll need to think about.
    "name": "x",
    "fn": dist.Bernoulli(0.5),
    "value": None,  # msg["value"] will eventually contain the value returned by pyro.sample
    "is_observed": False,  # because obs=None by default; only used by sample sites
    "args": (),  # positional arguments passed to "fn" when it is called; usually empty for sample sites
    "kwargs": {},  # keyword arguments passed to "fn" when it is called; usually empty for sample sites
    # This field typically contains metadata needed or stored by a particular inference algorithm
    "infer": {"enumerate": "parallel"},
    # The remaining fields are generally only used by Pyro's internals,
    # or for implementing more advanced effects beyond the scope of this tutorial
    "type": "sample",  # label used by Messenger._process_message to dispatch, in this case to _pyro_sample
    "done": False,
    "stop": False,
    "scale": torch.tensor(1.),  # Multiplicative scale factor that can be applied to each site's log_prob
    "mask": None,
    "continuation": None,
    "cond_indep_stack": (),  # Will contain metadata from each pyro.plate enclosing this sample site.
}

请注意,当我们像前两个版本的 make_log_joint 那样使用 poutine.traceTraceMessenger 时,msg 的内容正是 trace 中为每个 sample 和 param site 存储的信息。

使用现有 effect handler 实现推断算法:示例

事实证明,许多推断操作,例如上面我们的第一个版本的 make_log_joint,使用 pyro.poutine 中现有的 effect handler 实现时出奇地短。

示例:使用蒙特卡洛 ELBO 进行变分推断

例如,这里是使用蒙特卡洛 ELBO 进行变分推断的实现,它使用了 poutine.trace`、poutine.condition` 和 poutine.replay`。这与 pyro.contrib.minipyro 中的简单 ELBO 非常相似。`

[8]:
def monte_carlo_elbo(model, guide, batch, *args, **kwargs):
    # assuming batch is a dictionary, we use poutine.condition to fix values of observed variables
    conditioned_model = poutine.condition(model, data=batch)

    # we'll approximate the expectation in the ELBO with a single sample:
    # first, we run the guide forward unmodified and record values and distributions
    # at each sample site using poutine.trace
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)

    # we use poutine.replay to set the values of latent variables in the model
    # to the values sampled above by our guide, and use poutine.trace
    # to record the distributions that appear at each sample site in in the model
    model_trace = poutine.trace(
        poutine.replay(conditioned_model, trace=guide_trace)
    ).get_trace(*args, **kwargs)

    elbo = 0.
    for name, node in model_trace.nodes.items():
        if node["type"] == "sample":
            elbo = elbo + node["fn"].log_prob(node["value"]).sum()
            if not node["is_observed"]:
                elbo = elbo - guide_trace.nodes[name]["fn"].log_prob(node["value"]).sum()
    return -elbo

我们使用 poutine.tracepoutine.block 来记录用于优化的 pyro.param 调用

[9]:
def train(model, guide, data):
    optimizer = pyro.optim.Adam({})
    for batch in data:
        # this poutine.trace will record all of the parameters that appear in the model and guide
        # during the execution of monte_carlo_elbo
        with poutine.trace() as param_capture:
            # we use poutine.block here so that only parameters appear in the trace above
            with poutine.block(hide_fn=lambda node: node["type"] != "param"):
                loss = monte_carlo_elbo(model, guide, batch)

        loss.backward()
        params = set(node["value"].unconstrained()
                     for node in param_capture.trace.nodes.values())
        optimizer.step(params)
        pyro.infer.util.zero_grads(params)

示例:通过序列枚举进行精确推断

这里是一个使用 pyro.poutine 实现的完全不同的推断算法——通过枚举进行精确推断的示例。该算法的完整解释超出了本教程的范围,可以在这本简短的在线书籍 概率编程语言的设计与实现 的第三章中找到。此示例使用 poutine.queue,它本身是使用 poutine.trace`、poutine.replay` 和 poutine.block` 实现的,用于枚举模型中所有离散变量的可能值,并计算所有可能返回值或特定 sample site 可能值的边际分布。`

[10]:
def sequential_discrete_marginal(model, data, site_name="_RETURN"):

    from six.moves import queue  # queue data structures
    q = queue.Queue()  # Instantiate a first-in first-out queue
    q.put(poutine.Trace())  # seed the queue with an empty trace

    # as before, we fix the values of observed random variables with poutine.condition
    # assuming data is a dictionary whose keys are names of sample sites in model
    conditioned_model = poutine.condition(model, data=data)

    # we wrap the conditioned model in a poutine.queue,
    # which repeatedly pushes and pops partially completed executions from a Queue()
    # to perform breadth-first enumeration over the set of values of all discrete sample sites in model
    enum_model = poutine.queue(conditioned_model, queue=q)

    # actually perform the enumeration by repeatedly tracing enum_model
    # and accumulate samples and trace log-probabilities for postprocessing
    samples, log_weights = [], []
    while not q.empty():
        trace = poutine.trace(enum_model).get_trace()
        samples.append(trace.nodes[site_name]["value"])
        log_weights.append(trace.log_prob_sum())

    # we take the samples and log-joints and turn them into a histogram:
    samples = torch.stack(samples, 0)
    log_weights = torch.stack(log_weights, 0)
    log_weights = log_weights - dist.util.logsumexp(log_weights, dim=0)
    return dist.Empirical(samples, log_weights)

(注意,sequential_discrete_marginal 非常通用,但速度也相当慢。对于适用于不太通用模型类别的高性能并行枚举,请参阅枚举教程。)

示例:使用 Messenger API 实现惰性求值

现在我们已经更多地了解了 Messenger 的内部机制,让我们用它来实现一个稍微复杂的 effect:惰性求值。我们首先定义一个 LazyValue 类,我们将用它来构建计算图

[11]:
class LazyValue:
    def __init__(self, fn, *args, **kwargs):
        self._expr = (fn, args, kwargs)
        self._value = None

    def __str__(self):
        return "({} {})".format(str(self._expr[0]), " ".join(map(str, self._expr[1])))

    def evaluate(self):
        if self._value is None:
            fn, args, kwargs = self._expr
            fn = fn.evaluate() if isinstance(fn, LazyValue) else fn
            args = tuple(arg.evaluate() if isinstance(arg, LazyValue) else arg
                         for arg in args)
            kwargs = {k: v.evaluate() if isinstance(v, LazyValue) else v
                      for k, v in kwargs.items()}
            self._value = fn(*args, **kwargs)
        return self._value

有了 LazyValue,将惰性求值实现为与其他 effect handler 兼容的 Messenger 惊人地容易。我们只需将每个 msg["value"] 变成一个 LazyValue,并为确定性操作引入一个新的操作类型 "apply"`.`

[12]:
class LazyMessenger(pyro.poutine.messenger.Messenger):
    def _process_message(self, msg):
        if msg["type"] in ("apply", "sample") and not msg["done"]:
            msg["done"] = True
            msg["value"] = LazyValue(msg["fn"], *msg["args"], **msg["kwargs"])

最后,就像 torch.autograd 重载 torch tensor 操作以记录 autograd 图一样,我们需要包装任何我们希望惰性化的操作。我们将使用 pyro.poutine.runtime.effectful 作为装饰器,将这些操作暴露给 LazyMessenger`。effectful 构建一个与上面非常相似的消息,并在 effect handler 栈中向上和向下发送,但允许我们设置类型(在本例中设置为 "apply" 而不是 "sample"`),这样这些操作就不会被 TraceMessenger 等其他 effect handler 误认为是 sample 语句。

[13]:
@effectful(type="apply")
def add(x, y):
    return x + y

@effectful(type="apply")
def mul(x, y):
    return x * y

@effectful(type="apply")
def sigmoid(x):
    return torch.sigmoid(x)

@effectful(type="apply")
def normal(loc, scale):
    return dist.Normal(loc, scale)

应用于另一个模型

[14]:
def biased_scale(guess):
    weight = pyro.sample("weight", normal(guess, 1.))
    tolerance = pyro.sample("tolerance", normal(0., 0.25))
    return pyro.sample("measurement", normal(add(mul(weight, 0.8), 1.), sigmoid(tolerance)))

with LazyMessenger():
    v = biased_scale(8.5)
    print(v)
    print(v.evaluate())
((<function normal at 0x7fc41cbfdc80> (<function add at 0x7fc41cbf91e0> (<function mul at 0x7fc41cbfda60> ((<function normal at 0x7fc41cbfdc80> 8.5 1.0) ) 0.8) 1.0) (<function sigmoid at 0x7fc41cbfdb70> ((<function normal at 0x7fc41cbfdc80> 0.0 0.25) ))) )
tensor(6.5436)

TraceMessengerConditionMessenger 等其他可以自由组合的 effect handler 一起,LazyMessenger 展示了如何使用 Poutine 快速简洁地实现最先进的 PPL 技术,例如 使用 Rao-Blackwellization 的延迟采样

参考文献:编程语言研究中的代数效应和 handler

本节包含一些供对此方向感兴趣的读者的 PL 论文参考文献。

代数效应和 handler 从 21 世纪初开始发展,并且是编程语言社区中活跃研究的主题,它们是一种通用的抽象,用于构建编程语言中特定语句(如 pyro.samplepyro.param)的非标准解释器的模块化实现。它们最初被引入是为了解决使用 monad 和 monad transformer 实现非标准解释器的组合困难。

  • 对于 effect handler 文献的易懂介绍,请参阅 Ohad Kammar、Sam Lindley 和 Nicolas Oury 合著的优秀综述/教程论文 “Handlers in Action”,以及其中的参考文献。

  • 代数 effect handler 最初由 Gordon Plotkin 和 Matija Pretnar 在论文 “Handlers of Algebraic Effects” 中引入。

  • effect handler 的一个有用心智模型是将其视为异常 handler,它能够在引发异常并在 except 块中执行一些处理后,恢复 try 块中的计算。这个比喻在实验性编程语言 Eff 及其伴随论文 “Programming with Algebraic Effects and Handlers”(作者 Andrej Bauer 和 Matija Pretnar)中得到了进一步探讨。

  • Pyro 中的大多数 effect handler 是“线性的”,这意味着它们在每次 effectful 操作中只恢复一次,并且不改变原始程序的执行顺序。一个例外是 poutine.queue,它使用了一种效率较低的多重恢复实现策略,例如 James Koppel、Gabriel Scherer 和 Armando Solar-Lezama 在论文 “Capturing the Future by Replaying the Past” 中描述的 delimited continuations。

  • 在 Python 或 JavaScript 等主流编程语言中实现 effect handler 的更有效策略是一个活跃的研究领域。一个有前景的研究方向涉及选择性 continuation-passing style 转换,如 Daan Leijen 在论文 “Type-Directed Compilation of Row-Typed Algebraic Effects” 中所述。