Variational Autoencoder with a Normalizing Flow Prior

Using a normalizing flow as the prior over the latent variables, instead of the typical standard Gaussian, is a simple way to make a variational autoencoder (VAE) more expressive. This notebook demonstrates how to implement a VAE with a normalizing flow prior for the MNIST dataset. We strongly recommend reading Pyro's VAE tutorial first.
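Concretely, a VAE defines the marginal likelihood

\[ p_\phi(x) = \int p_\phi(x \mid z) \, p_\phi(z) \, \mathrm{d}z, \]

so replacing the fixed prior \(\mathcal{N}(0, I)\) with a learned flow \(p_\phi(z)\) gives the generative model extra flexibility.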

In this notebook, we use Zuko to implement normalizing flows, but similar results can be obtained with other PyTorch-based flow libraries.

[1]:
import pyro
import torch
import torch.nn as nn
import torch.utils.data as data
import zuko

from pyro.contrib.zuko import ZukoToPyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
from torch import Tensor
from torchvision.datasets import MNIST
from torchvision.transforms.functional import to_tensor, to_pil_image
from tqdm import tqdm

Data

The MNIST dataset consists of 28 x 28 grayscale images of handwritten digits (0 to 9).

[2]:
trainset = MNIST(root='', download=True, train=True, transform=to_tensor)
trainloader = data.DataLoader(trainset, batch_size=256, shuffle=True)
[3]:
x = [trainset[i][0] for i in range(16)]
x = torch.cat(x, dim=-1)

to_pil_image(x)
[3]:
[Output image: a row of 16 MNIST training digits]

Model

As in the previous tutorial, we choose a (diagonal) Gaussian model as the encoder \(q_\psi(z | x)\) and a Bernoulli model as the decoder \(p_\phi(x | z)\).
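Written out, the two networks parametrize

\[ q_\psi(z \mid x) = \mathcal{N}\big(z; \mu_\psi(x), \operatorname{diag}(\sigma_\psi(x)^2)\big), \qquad p_\phi(x \mid z) = \prod_i \mathcal{B}\big(x_i; \rho_{\phi, i}(z)\big), \]

where \(\mathcal{B}\) denotes a Bernoulli distribution, the encoder network outputs \(\mu_\psi(x)\) and \(\log \sigma_\psi(x)\), and the decoder network outputs the probabilities \(\rho_\phi(z)\) through a sigmoid. This matches the chunk and sigmoid operations in the code below.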

[4]:
class GaussianEncoder(nn.Module):
    def __init__(self, features: int, latent: int):
        super().__init__()

        self.hyper = nn.Sequential(
            nn.Linear(features, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2 * latent),
        )

    def forward(self, x: Tensor):
        phi = self.hyper(x)
        mu, log_sigma = phi.chunk(2, dim=-1)  # split network output into mean and log-std

        return pyro.distributions.Normal(mu, log_sigma.exp()).to_event(1)


class BernoulliDecoder(nn.Module):
    def __init__(self, features: int, latent: int):
        super().__init__()

        self.hyper = nn.Sequential(
            nn.Linear(latent, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, features),
        )

    def forward(self, z: Tensor):
        phi = self.hyper(z)
        rho = torch.sigmoid(phi)

        return pyro.distributions.Bernoulli(rho).to_event(1)

However, instead of the typical standard Gaussian \(\mathcal{N}(0, I)\), we choose a masked autoregressive flow (MAF) as the prior \(p_\phi(z)\). Rather than implementing the MAF ourselves, we borrow one from the Zuko library. Because Zuko distributions are very similar to Pyro distributions, a thin wrapper (ZukoToPyro) is all it takes to make Zuko and Pyro 100% compatible.
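To make the bridge concrete, here is a minimal sketch (shapes are illustrative; it uses only the classes imported above): calling a Zuko flow returns a distribution object, and ZukoToPyro wraps that distribution so it can be passed to pyro.sample.

flow = zuko.flows.MAF(features=16)  # unconditional flow over 16 dimensions
dist = flow()                       # the forward pass returns a distribution
z = dist.sample((4,))               # shape (4, 16)
log_p = dist.log_prob(z)            # shape (4,)

# Wrapped, the same distribution behaves like a Pyro distribution:
with pyro.plate("batch", 4):
    z = pyro.sample("z", ZukoToPyro(flow()))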

[5]:
class VAE(nn.Module):
    def __init__(self, features: int, latent: int = 16):
        super().__init__()

        self.encoder = GaussianEncoder(features, latent)
        self.decoder = BernoulliDecoder(features, latent)

        self.prior = zuko.flows.MAF(
            features=latent,
            transforms=3,
            hidden_features=(256, 256),
        )

    def model(self, x: Tensor):
        # Register the networks with Pyro so SVI optimizes their parameters
        pyro.module("prior", self.prior)
        pyro.module("decoder", self.decoder)

        with pyro.plate("batch", len(x)):
            z = pyro.sample("z", ZukoToPyro(self.prior()))
            x = pyro.sample("x", self.decoder(z), obs=x)

    def guide(self, x: Tensor):
        pyro.module("encoder", self.encoder)

        with pyro.plate("batch", len(x)):
            z = pyro.sample("z", self.encoder(x))

vae = VAE(784, 16).cuda()
vae
[5]:
VAE(
  (encoder): GaussianEncoder(
    (hyper): Sequential(
      (0): Linear(in_features=784, out_features=1024, bias=True)
      (1): ReLU()
      (2): Linear(in_features=1024, out_features=1024, bias=True)
      (3): ReLU()
      (4): Linear(in_features=1024, out_features=32, bias=True)
    )
  )
  (decoder): BernoulliDecoder(
    (hyper): Sequential(
      (0): Linear(in_features=16, out_features=1024, bias=True)
      (1): ReLU()
      (2): Linear(in_features=1024, out_features=1024, bias=True)
      (3): ReLU()
      (4): Linear(in_features=1024, out_features=784, bias=True)
    )
  )
  (prior): MAF(
    (transform): LazyComposedTransform(
      (0): MaskedAutoregressiveTransform(
        (base): MonotonicAffineTransform()
        (order): [0, 1, 2, 3, 4, ..., 11, 12, 13, 14, 15]
        (hyper): MaskedMLP(
          (0): MaskedLinear(in_features=16, out_features=256, bias=True)
          (1): ReLU()
          (2): MaskedLinear(in_features=256, out_features=256, bias=True)
          (3): ReLU()
          (4): MaskedLinear(in_features=256, out_features=32, bias=True)
        )
      )
      (1): MaskedAutoregressiveTransform(
        (base): MonotonicAffineTransform()
        (order): [15, 14, 13, 12, 11, ..., 4, 3, 2, 1, 0]
        (hyper): MaskedMLP(
          (0): MaskedLinear(in_features=16, out_features=256, bias=True)
          (1): ReLU()
          (2): MaskedLinear(in_features=256, out_features=256, bias=True)
          (3): ReLU()
          (4): MaskedLinear(in_features=256, out_features=32, bias=True)
        )
      )
      (2): MaskedAutoregressiveTransform(
        (base): MonotonicAffineTransform()
        (order): [0, 1, 2, 3, 4, ..., 11, 12, 13, 14, 15]
        (hyper): MaskedMLP(
          (0): MaskedLinear(in_features=16, out_features=256, bias=True)
          (1): ReLU()
          (2): MaskedLinear(in_features=256, out_features=256, bias=True)
          (3): ReLU()
          (4): MaskedLinear(in_features=256, out_features=32, bias=True)
        )
      )
    )
    (base): Unconditional(DiagNormal(loc: torch.Size([16]), scale: torch.Size([16])))
  )
)

Training

We train our VAE with the standard stochastic variational inference (SVI) recipe.
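For reference, SVI maximizes the evidence lower bound (ELBO), which for this model reads

\[ \mathbb{E}_{q_\psi(z \mid x)} \big[ \log p_\phi(x \mid z) + \log p_\phi(z) - \log q_\psi(z \mid x) \big] \leq \log p_\phi(x), \]

where the learned flow prior \(p_\phi(z)\) enters through the middle term.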

[6]:
pyro.clear_param_store()

svi = SVI(vae.model, vae.guide, Adam({'lr': 1e-3}), loss=Trace_ELBO())

for epoch in (bar := tqdm(range(96))):
    losses = []

    for x, _ in trainloader:
        x = x.round().flatten(-3).cuda()  # binarize pixels and flatten images to 784-d vectors

        losses.append(svi.step(x))

    losses = torch.tensor(losses)

    bar.set_postfix(loss=losses.sum().item() / len(trainset))
100%|██████████| 96/96 [24:04<00:00, 15.05s/it, loss=63.1]

After training, we can generate MNIST images by sampling latent variables from the prior and decoding them.

[7]:
z = vae.prior().sample((16,))
x = vae.decoder(z).mean.reshape(-1, 28, 28)

to_pil_image(x.movedim(0, 1).reshape(28, -1))  # tile the 16 digits into a single row
[7]:
[Output image: a row of 16 generated MNIST digits]