条件变分自编码器¶
简介¶
本教程使用 Pyro PPL 实现了 Learning Structured Output Representation using Deep Conditional Generative Models 论文,该论文于 2015 年引入了条件变分自编码器。
监督深度学习已成功应用于机器学习和计算机视觉中的许多识别问题。尽管在提供大量训练数据时,它可以很好地近似一个复杂的“多对一”函数,但当前监督深度学习方法缺乏概率推断能力,使其难以建模复杂的结构化输出表示。在这项工作中,Kihyuk Sohn、Honglak Lee 和 Xinchen Yan 开发了一种可扩展的深度条件生成模型,用于使用高斯潜变量对结构化输出变量进行建模。该模型在随机梯度变分贝叶斯框架下高效训练,并允许使用随机前向推断进行快速预测。他们将该模型命名为条件变分自编码器 (CVAE)。
CVAE 是一种条件有向图模型,其输入观测调制生成输出的高斯潜变量上的先验。其训练目标是最大化条件边缘对数似然。作者在随机梯度变分贝叶斯 (SGVB) 框架下构建了 CVAE 的变分学习目标。在实验中,他们展示了 CVAE 在使用随机推断生成多样化但真实的输出预测方面,相对于确定性神经网络同类方法的有效性。在此,我们将实现他们的概念验证:一个使用 MNIST 数据库进行结构化输出预测的人工实验设置。
问题¶
让我们将每个数字图像分成四个象限,并将一个、两个或三个象限作为输入,剩余的象限作为待预测的输出。下图展示了以一个象限作为输入的情况。
我们的目标是学习一个能够执行概率推断并从单个输入中做出多样化预测的模型。这是因为我们不是简单地像分类任务那样建模一个“多对一”函数,而是可能需要建模从单个输入到许多可能输出的映射。确定性神经网络的一个局限性在于它们只能生成一个预测。在上面的示例中,输入显示了数字的一小部分,这可能是三或五。
数据准备¶
我们使用 MNIST 数据集;第一步是准备数据。根据我们将用作输入的象限数量,我们将构建数据集和数据加载器,并使用 -1 移除未使用的像素。
class CVAEMNIST(Dataset):
def __init__(self, root, train=True, transform=None, download=False):
self.original = MNIST(root, train=train, download=download)
self.transform = transform
def __len__(self):
return len(self.original)
def __getitem__(self, item):
image, digit = self.original[item]
sample = {'original': image, 'digit': digit}
if self.transform:
sample = self.transform(sample)
return sample
class ToTensor:
def __call__(self, sample):
sample['original'] = functional.to_tensor(sample['original'])
sample['digit'] = torch.as_tensor(np.asarray(sample['digit']),
dtype=torch.int64)
return sample
class MaskImages:
"""This torchvision image transformation prepares the MNIST digits to be
used in the tutorial. Depending on the number of quadrants to be used as
inputs (1, 2, or 3), the transformation masks the remaining (3, 2, 1)
quadrant(s) setting their pixels with -1. Additionally, the transformation
adds the target output in the sample dict as the complementary of the input
"""
def __init__(self, num_quadrant_inputs, mask_with=-1):
if num_quadrant_inputs <= 0 or num_quadrant_inputs >= 4:
raise ValueError('Number of quadrants as inputs must be 1, 2 or 3')
self.num = num_quadrant_inputs
self.mask_with = mask_with
def __call__(self, sample):
tensor = sample['original'].squeeze()
out = tensor.detach().clone()
h, w = tensor.shape
# removes the bottom left quadrant from the target output
out[h // 2:, :w // 2] = self.mask_with
# if num of quadrants to be used as input is 2,
# also removes the top left quadrant from the target output
if self.num == 2:
out[:, :w // 2] = self.mask_with
# if num of quadrants to be used as input is 3,
# also removes the top right quadrant from the target output
if self.num == 3:
out[:h // 2, :] = self.mask_with
# now, sets the input as complementary
inp = tensor.clone()
inp[out != -1] = self.mask_with
sample['input'] = inp
sample['output'] = out
return sample
def get_data(num_quadrant_inputs, batch_size):
transforms = Compose([
ToTensor(),
MaskImages(num_quadrant_inputs=num_quadrant_inputs)
])
datasets, dataloaders, dataset_sizes = {}, {}, {}
for mode in ['train', 'val']:
datasets[mode] = CVAEMNIST(
'../data',
download=True,
transform=transforms,
train=mode == 'train'
)
dataloaders[mode] = DataLoader(
datasets[mode],
batch_size=batch_size,
shuffle=mode == 'train',
num_workers=0
)
dataset_sizes[mode] = len(datasets[mode])
return datasets, dataloaders, dataset_sizes
基线:确定性神经网络¶
在深入 CVAE 实现之前,我们先编写基线模型的代码。这是一个直接的实现。
class BaselineNet(nn.Module):
def __init__(self, hidden_1, hidden_2):
super().__init__()
self.fc1 = nn.Linear(784, hidden_1)
self.fc2 = nn.Linear(hidden_1, hidden_2)
self.fc3 = nn.Linear(hidden_2, 784)
self.relu = nn.ReLU()
def forward(self, x):
x = x.view(-1, 784)
hidden = self.relu(self.fc1(x))
hidden = self.relu(self.fc2(hidden))
y = torch.sigmoid(self.fc3(hidden))
return y
在论文中,作者通过比较验证集上每个图像的负(条件)对数似然 (CLL) 来比较基线 NN 与提出的 CVAE。得益于 PyTorch,计算 CLL 等同于计算经过 Sigmoid 层后的信号的二元交叉熵损失。下面的代码进行了一些小调整以利用这一点:它只计算未用 -1 掩码的像素的损失。
class MaskedBCELoss(nn.Module):
def __init__(self, masked_with=-1):
super().__init__()
self.masked_with = masked_with
def forward(self, input, target):
target = target.view(input.shape)
loss = F.binary_cross_entropy(input, target, reduction='none')
loss[target == self.masked_with] = 0
return loss.sum()
训练过程非常简单。我们在每个隐藏层使用 500 个神经元,使用学习率为 1e-3
的 Adam 优化器,并采用提前停止。请查看 Github 仓库以获取完整的实现。
用于结构化输出预测的深度条件生成模型¶
如下图所示,深度条件生成模型 (CGM) 中有三种类型的变量:输入变量 \(\bf x\)、输出变量 \(\bf y\) 和潜变量 \(\bf z\)。模型的条件生成过程在 (b) 中给出如下:对于给定的观测 \(\bf x\),\(\bf z\) 从先验分布 \(p_{\theta}({\bf z} | {\bf x})\) 中抽取,输出 \(\bf y\) 从分布 \(p_{\theta}({\bf y} | {\bf x, z})\) 中生成。与基线 NN (a) 相比,潜变量 \(\bf z\) 允许对给定输入 \(\bf x\) 的输出变量 \(\bf y\) 的条件分布中的多个模式进行建模,使得提出的 CGM 适合建模“一对多”映射。
深度 CGM 的训练目标是最大化条件边缘对数似然。目标函数通常难以处理,我们应用 SGVB 框架来训练模型。经验下界写为
其中 \(\bf z^{(l)}\) 是一个高斯潜变量,\(L\) 是样本数(或 Pyro 术语中的粒子数)。我们将此模型称为条件变分自编码器 (CVAE)。CVAE 由多个 MLP 组成,例如识别网络 \(q_{\phi}({\bf z} | \bf{x, y})\)、(条件)先验网络 \(p_{\theta}(\bf{z} | \bf{x})\) 和生成网络 \(p_{\theta}(\bf{y} | \bf{x, z})\)。在设计网络架构时,我们将 CVAE 的网络组件构建在基线 NN 的基础之上。具体来说,如上图 (d) 所示,不仅直接输入 \(\bf x\),NN 产生的初始猜测 \(\hat{y}\) 也被馈送到先验网络。
Pyro 使将此架构转换为代码变得非常容易。识别网络和(条件)先验网络是传统 VAE 设置中的编码器,而生成网络是解码器。
class Encoder(nn.Module):
def __init__(self, z_dim, hidden_1, hidden_2):
super().__init__()
self.fc1 = nn.Linear(784, hidden_1)
self.fc2 = nn.Linear(hidden_1, hidden_2)
self.fc31 = nn.Linear(hidden_2, z_dim)
self.fc32 = nn.Linear(hidden_2, z_dim)
self.relu = nn.ReLU()
def forward(self, x, y):
# put x and y together in the same image for simplification
xc = x.clone()
xc[x == -1] = y[x == -1]
xc = xc.view(-1, 784)
# then compute the hidden units
hidden = self.relu(self.fc1(xc))
hidden = self.relu(self.fc2(hidden))
# then return a mean vector and a (positive) square root covariance
# each of size batch_size x z_dim
z_loc = self.fc31(hidden)
z_scale = torch.exp(self.fc32(hidden))
return z_loc, z_scale
class Decoder(nn.Module):
def __init__(self, z_dim, hidden_1, hidden_2):
super().__init__()
self.fc1 = nn.Linear(z_dim, hidden_1)
self.fc2 = nn.Linear(hidden_1, hidden_2)
self.fc3 = nn.Linear(hidden_2, 784)
self.relu = nn.ReLU()
def forward(self, z):
y = self.relu(self.fc1(z))
y = self.relu(self.fc2(y))
y = torch.sigmoid(self.fc3(y))
return y
class CVAE(nn.Module):
def __init__(self, z_dim, hidden_1, hidden_2, pre_trained_baseline_net):
super().__init__()
# The CVAE is composed of multiple MLPs, such as recognition network
# qφ(z|x, y), (conditional) prior network pθ(z|x), and generation
# network pθ(y|x, z). Also, CVAE is built on top of the NN: not only
# the direct input x, but also the initial guess y_hat made by the NN
# are fed into the prior network.
self.baseline_net = pre_trained_baseline_net
self.prior_net = Encoder(z_dim, hidden_1, hidden_2)
self.generation_net = Decoder(z_dim, hidden_1, hidden_2)
self.recognition_net = Encoder(z_dim, hidden_1, hidden_2)
def model(self, xs, ys=None):
# register this pytorch module and all of its sub-modules with pyro
pyro.module("generation_net", self)
batch_size = xs.shape[0]
with pyro.plate("data"):
# Prior network uses the baseline predictions as initial guess.
# This is the generative process with recurrent connection
with torch.no_grad():
# this ensures the training process does not change the
# baseline network
y_hat = self.baseline_net(xs).view(xs.shape)
# sample the handwriting style from the prior distribution, which is
# modulated by the input xs.
prior_loc, prior_scale = self.prior_net(xs, y_hat)
zs = pyro.sample('z', dist.Normal(prior_loc, prior_scale).to_event(1))
# the output y is generated from the distribution pθ(y|x, z)
loc = self.generation_net(zs)
if ys is not None:
# In training, we will only sample in the masked image
mask_loc = loc[(xs == -1).view(-1, 784)].view(batch_size, -1)
mask_ys = ys[xs == -1].view(batch_size, -1)
pyro.sample('y', dist.Bernoulli(mask_loc).to_event(1), obs=mask_ys)
else:
# In testing, no need to sample: the output is already a
# probability in [0, 1] range, which better represent pixel
# values considering grayscale. If we sample, we will force
# each pixel to be either 0 or 1, killing the grayscale
pyro.deterministic('y', loc.detach())
# return the loc so we can visualize it later
return loc
def guide(self, xs, ys=None):
with pyro.plate("data"):
if ys is None:
# at inference time, ys is not provided. In that case,
# the model uses the prior network
y_hat = self.baseline_net(xs).view(xs.shape)
loc, scale = self.prior_net(xs, y_hat)
else:
# at training time, uses the variational distribution
# q(z|x,y) = normal(loc(x,y),scale(x,y))
loc, scale = self.recognition_net(xs, ys)
pyro.sample("z", dist.Normal(loc, scale).to_event(1))
训练¶
训练代码可以在 Github 仓库中找到。点击下方视频播放按钮,观看 CVAE 在大约 40 个 epoch 的学习过程。
正如我们所见,随着训练的进行,模型学习到的后验分布持续改进:不仅损失下降,而且我们还能清楚地看到预测如何变得越来越好。
此外,在这里我们已经可以看到 CVAE 的关键优势:模型学会从单个输入中生成多个预测。对于第一个数字,输入显然是 7 的一部分。模型学习到这一点,并持续预测更清晰的 7,但具有不同的书写风格。对于第二个和第三个数字,输入可能是 3 或 5 的一部分(真实是 3),以及可能是 4 或 9 的一部分(真实是 4)。在最初的几个 epoch 中,CVAE 的预测是模糊的,随着时间的推移变得更清晰,这是预料之中的。
然而,与第一个数字不同,仅观察数字的四分之一作为输入时,很难确定第二个和第三个数字的真实情况分别是 3 和 4。到训练结束时,CVAE 生成了非常清晰和真实的预测,但它并没有强制第二个数字是 3 或 5,也没有强制第三个数字是 4 或 9。有时它会预测一个选项,有时会预测另一个。
评估结果¶
为了进行定性分析,我们在下图中可视化了生成的输出样本。正如我们所见,基线 NN 只能进行单一的确定性预测,结果是输出看起来模糊不清,在许多情况下并不逼真。相比之下,CVAE 模型生成的样本形状更逼真且多样化;有时它们甚至可以改变其身份(数字标签),例如从 3 变为 5 或从 4 变为 9,反之亦然。
我们还通过估算下表中的边缘条件对数似然 (CLL) 提供了定量证据(值越低越好)。
1 个象限 |
2 个象限 |
3 个象限 |
|
---|---|---|---|
NN (基线) |
100.4 |
61.9 |
25.4 |
CVAE (蒙特卡洛) |
71.8 |
51.0 |
24.2 |
性能差距 |
28.6 |
10.9 |
1.2 |
我们获得了与论文作者相似的结果。我们仅训练了 50 个 epoch,并使用了 3 个 epoch 的提前停止耐心;为了改善结果,我们可以让算法训练更长时间。尽管如此,我们可以观察到论文中显示出的相同效果:CVAE 的估计 CLL 显著优于基线 NN。
在 Github 上查看完整代码。
参考文献¶
[1] Learning Structured Output Representation using Deep Conditional Generative Models
, Kihyuk Sohn, Xinchen Yan, Honglak Lee