理性言语行为框架

人类语言依赖于*合作*的假设,即说话者试图向听者提供相关信息;听者可以利用这一假设,根据说话者选择的话语,对世界可能的状态进行*语用*推理。

理性言语行为框架利用概率决策和推理将这些思想形式化。

注意:此 Notebook 必须在 Pyro 4392d54a220c328ee356600fb69f82166330d3d6 或更高版本下运行。

[1]:
#first some imports
import torch
torch.set_default_dtype(torch.float64)  # double precision for numerical stability

import collections
import argparse
import matplotlib.pyplot as plt

import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

from search_inference import HashingMarginal, memoize, Search

在定义 RSA 之前,我们首先指定一个封装了推断的辅助函数。Marginal 接收一个非归一化的随机函数,通过使用 Search 构建执行轨迹上的分布,并通过 HashingMarginal 构建返回值上的边际分布。

[2]:
def Marginal(fn):
    return memoize(lambda *args: HashingMarginal(Search(fn).run(*args)))

RSA 模型捕获了递归的社会推理——听者思考说话者,说话者思考听者……。

首先,literal_listener 简单地强制规定话语必须为真。数学上表示为

\[P_\text{Lit}(s|u) \propto {\mathcal L}(u,s)P(s)\]

代码如下

[3]:
@Marginal
def literal_listener(utterance):
    state = state_prior()
    pyro.factor("literal_meaning", 0. if meaning(utterance, state) else -999999.)
    return state

接着,合作的说话者选择一个话语,以将给定状态传达给字面听者。数学上表示为

\[P_S(u|s) \propto [P_\text{Lit}(s|u) P(u)]^\alpha\]

在下面的代码中,utterance_prior 捕获了产生话语的成本,而 pyro.sample 表达式捕获了字面听者猜测到正确状态(obs=state 表示采样值被观察到是正确的 state)。

我们使用 poutine.scale 将整个执行概率提高到 alpha 次幂——这将得到一个带有最优性参数 alpha 的 softmax 决策规则。

[4]:
@Marginal
def speaker(state):
    alpha = 1.
    with poutine.scale(scale=torch.tensor(alpha)):
        utterance = utterance_prior()
        pyro.sample("listener", literal_listener(utterance), obs=state)
    return utterance

最后,我们可以定义语用听者,他根据说话者选择的给定话语来推断可能的状态。数学上表示为

\[P_L(s|u) \propto P_S(u|s) P(s)\]

代码如下

[5]:
@Marginal
def pragmatic_listener(utterance):
    state = state_prior()
    pyro.sample("speaker", speaker(state), obs=utterance)
    return state

现在,让我们通过填充先验来建立一个简单的世界。我们假设有 4 个物体,每个物体要么是蓝色要么是红色,可能的话语是“没有是蓝色的”、“一些是蓝色的”、“所有都是蓝色的”。

我们假设蓝色物体数量和话语的先验概率是均匀分布的。

[6]:
total_number = 4

def state_prior():
    n = pyro.sample("state", dist.Categorical(probs=torch.ones(total_number+1) / total_number+1))
    return n

def utterance_prior():
    ix = pyro.sample("utt", dist.Categorical(probs=torch.ones(3) / 3))
    return ["none","some","all"][ix]

最后,意义函数(上面标记为 \(\mathcal L\)

[7]:
meanings = {
    "none": lambda N: N==0,
    "some": lambda N: N>0,
    "all": lambda N: N==total_number,
}

def meaning(utterance, state):
    return meanings[utterance](state)

现在让我们看看它是否有效:语用听者如何解释“一些”这个话语?

[8]:
#silly plotting helper:
def plot_dist(d):
    support = d.enumerate_support()
    data = [d.log_prob(s).exp().item() for s in d.enumerate_support()]
    names = list(map(str, support))

    ax = plt.subplot(111)
    width = 0.3
    bins = [x-width/2 for x in range(1, len(data) + 1)]
    ax.bar(bins,data,width=width)
    ax.set_xticks(list(range(1, len(data) + 1)))
    ax.set_xticklabels(names, rotation=45, rotation_mode="anchor", ha="right")

interp_dist = pragmatic_listener("some")
plot_dist(interp_dist)
_images/RSA-implicature_15_0.png

太棒了,我们得到了一个*数量蕴涵*:“一些”被解释为可能不包括全部 4 个。试试也看看 literal_listener——没有蕴涵。