Automatic rendering of Pyro models

In this tutorial we demonstrate how to use pyro.render_model() to create clean visualizations of your probabilistic graphical models.

[1]:
import os
import torch
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.1')

A simple example

The visualization interface can be used directly with your models:

[2]:
def model(data):
    m = pyro.sample("m", dist.Normal(0, 1))
    sd = pyro.sample("sd", dist.LogNormal(m, 1))
    with pyro.plate("N", len(data)):
        pyro.sample("obs", dist.Normal(m, sd), obs=data)
[3]:
data = torch.ones(10)
pyro.render_model(model, model_args=(data,))
[3]:
_images/model_rendering_4_0.svg

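You can save the visualization to a file by passing filename='path' to pyro.render_model. The suffix of the filename selects the output format, for example PDF or PNG. When you do not save to a file (filename=None), you can still change the format with graph.format = 'pdf', where graph is the object returned by pyro.render_model.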

[4]:
graph = pyro.render_model(model, model_args=(data,), filename="model.pdf")

Tweaking the visualization

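Because pyro.render_model returns an object of type graphviz.dot.Digraph, you can refine the visualization further. For example, the unflatten preprocessor can improve the layout aspect ratio of more complex models.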

[5]:
def mace(positions, annotations):
    """
    This model corresponds to the plate diagram in Figure 3 of https://www.aclweb.org/anthology/Q18-1040.pdf.
    """
    num_annotators = int(torch.max(positions)) + 1
    num_classes = int(torch.max(annotations)) + 1
    num_items, num_positions = annotations.shape

    with pyro.plate("annotator", num_annotators):
        epsilon = pyro.sample("ε", dist.Dirichlet(torch.full((num_classes,), 10.)))
        theta = pyro.sample("θ", dist.Beta(0.5, 0.5))

    with pyro.plate("item", num_items, dim=-2):
        # NB: using constant logits for discrete uniform prior
        # (NumPyro does not have DiscreteUniform distribution yet)
        c = pyro.sample("c", dist.Categorical(logits=torch.zeros(num_classes)))

        with pyro.plate("position", num_positions):
            s = pyro.sample("s", dist.Bernoulli(1 - theta[positions]))
            probs = torch.where(
                s[..., None] == 0, F.one_hot(c, num_classes).float(), epsilon[positions]
            )
            pyro.sample("y", dist.Categorical(probs), obs=annotations)


positions = torch.tensor([1, 1, 1, 2, 3, 4, 5])
# fmt: off
annotations = torch.tensor([
    [1, 3, 1, 2, 2, 2, 1, 3, 2, 2, 4, 2, 1, 2, 1,
     1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,
     1, 3, 1, 2, 2, 4, 2, 2, 3, 1, 1, 1, 2, 1, 2],
    [1, 3, 1, 2, 2, 2, 2, 3, 2, 3, 4, 2, 1, 2, 2,
     1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 3, 1, 1, 1,
     1, 3, 1, 2, 2, 3, 2, 3, 3, 1, 1, 2, 3, 2, 2],
    [1, 3, 2, 2, 2, 2, 2, 3, 2, 2, 4, 2, 1, 2, 1,
     1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 2, 1, 1, 2,
     1, 3, 1, 2, 2, 3, 1, 2, 3, 1, 1, 1, 2, 1, 2],
    [1, 4, 2, 3, 3, 3, 2, 3, 2, 2, 4, 3, 1, 3, 1,
     2, 1, 1, 2, 1, 2, 2, 3, 2, 1, 1, 2, 1, 1, 1,
     1, 3, 1, 2, 3, 4, 2, 3, 3, 1, 1, 2, 2, 1, 2],
    [1, 3, 1, 1, 2, 3, 1, 4, 2, 2, 4, 3, 1, 2, 1,
     1, 1, 1, 2, 3, 2, 2, 2, 2, 1, 1, 2, 1, 1, 1,
     1, 2, 1, 2, 2, 3, 2, 2, 4, 1, 1, 1, 2, 1, 2],
    [1, 3, 2, 2, 2, 2, 1, 3, 2, 2, 4, 4, 1, 1, 1,
     1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 1, 1, 2,
     1, 3, 1, 2, 3, 4, 3, 3, 3, 1, 1, 1, 2, 1, 2],
    [1, 4, 2, 1, 2, 2, 1, 3, 3, 3, 4, 3, 1, 2, 1,
     1, 1, 1, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1,
     1, 3, 1, 2, 2, 3, 2, 3, 2, 1, 1, 1, 2, 1, 2],
]).T
# fmt: on

# we subtract 1 because the first index starts with 0 in Python
positions -= 1
annotations -= 1

mace_graph = pyro.render_model(mace, model_args=(positions, annotations))
[6]:
# default layout
mace_graph
[6]:
_images/model_rendering_9_0.svg
[7]:
# layout after processing the layout with unflatten
mace_graph.unflatten(stagger=2)
[7]:
_images/model_rendering_10_0.svg
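Since the returned object is an ordinary Graphviz Digraph, generic graphviz methods apply to it as well. The sketch below uses a hypothetical fan-out graph built directly with the graphviz package in place of a Pyro model. It shows two independent tweaks: setting the rankdir graph attribute on a copy (a left-to-right layout), and calling unflatten(stagger=2), which is attempted only when the Graphviz unflatten tool is installed.

```python
import shutil

import graphviz

# Hypothetical stand-in for a wide model: one parent with many children
# produces a very wide default layout.
wide = graphviz.Digraph()
for i in range(8):
    wide.edge("parent", f"child_{i}")

# Tweak 1: lay out a copy left-to-right instead of top-to-bottom.
left_to_right = wide.copy()
left_to_right.attr(rankdir="LR")

# Tweak 2: unflatten() shells out to the Graphviz `unflatten` tool and
# returns a new graphviz.Source with staggered leaf nodes.
if shutil.which("unflatten") is not None:
    staggered = wide.unflatten(stagger=2)
    print(staggered.source)
```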

Rendering the parameters

By setting render_params=True in pyro.render_model, we can also render the parameters defined with pyro.param.

[8]:
def model(data):
    sigma = pyro.param("sigma", torch.tensor([1.]), constraint=constraints.positive)
    mu = pyro.param("mu", torch.tensor([0.]))
    x = pyro.sample("x", dist.Normal(mu, sigma))
    y = pyro.sample("y", dist.LogNormal(x, 1))
    with pyro.plate("N", len(data)):
        pyro.sample("z", dist.Normal(x, y), obs=data)
[9]:
data = torch.ones(10)
pyro.render_model(model, model_args=(data,), render_params=True)
[9]:
_images/model_rendering_14_0.svg

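Distribution and constraint annotations

Passing render_distributions=True to pyro.render_model displays the distribution of each random variable (RV) in the generated graph. With render_distributions=True, the constraints associated with parameters are displayed as well.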

[10]:
data = torch.ones(10)
pyro.render_model(model, model_args=(data,), render_params=True, render_distributions=True)
[10]:
_images/model_rendering_16_0.svg

In the graph above, ‘~’ denotes the distribution of an RV and ‘∈’ denotes the constraint of a parameter.

Overlapping non-nested plates

Note that overlapping non-nested plates may be drawn as multiple rectangles.

[11]:
def model():
    plate1 = pyro.plate("plate1", 2, dim=-2)
    plate2 = pyro.plate("plate2", 3, dim=-1)
    with plate1:
        x = pyro.sample("x", dist.Normal(0, 1))
    with plate1, plate2:
        y = pyro.sample("y", dist.Normal(x, 1))
    with plate2:
        pyro.sample("z", dist.Normal(y.sum(-2, True), 1), obs=torch.zeros(3))
[12]:
pyro.render_model(model)
[12]:
_images/model_rendering_20_0.svg

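Semi-supervised models

Pyro supports semi-supervised models by allowing different sets of *args, **kwargs to be passed to the model. You can render a semi-supervised model by passing a list of model_args tuples and/or a list of model_kwargs dicts, one entry for each way the model is used.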

[13]:
def model(x, y=None):
    with pyro.plate("N", 2):
        z = pyro.sample("z", dist.Normal(0, 1))
        y = pyro.sample("y", dist.Normal(0, 1), obs=y)
        pyro.sample("x", dist.Normal(y + z, 1), obs=x)
[14]:
pyro.render_model(
    model,
    model_kwargs=[
        {"x": torch.zeros(2)},
        {"x": torch.zeros(2), "y": torch.zeros(2)},
    ]
)
[14]:
_images/model_rendering_23_0.svg

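Rendering deterministic variables

Pyro allows deterministic variables to be defined with pyro.deterministic. These variables can be rendered by setting render_deterministic=True in pyro.render_model, as shown below: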

[15]:
def model_deterministic(data):
    sigma = pyro.param("sigma", torch.tensor([1.]), constraint=constraints.positive)
    mu = pyro.param("mu", torch.tensor([0.]))
    x = pyro.sample("x", dist.Normal(mu, sigma))
    log_y = pyro.sample("y", dist.Normal(x, 1))
    y = pyro.deterministic("y_deterministic", log_y.exp())
    with pyro.plate("N", len(data)):
        eps_z_loc = pyro.sample("eps_z_loc", dist.Normal(0, 1))
        z_loc = pyro.deterministic("z_loc", eps_z_loc + x, event_dim=0)
        pyro.sample("z", dist.Normal(z_loc, y), obs=data)
[16]:
data = torch.ones(10)
pyro.render_model(
    model_deterministic,
    model_args=(data,),
    render_params=True,
    render_distributions=True,
    render_deterministic=True
)
[16]:
_images/model_rendering_26_0.svg