理性言语行为框架¶
人类语言依赖于*合作*的假设,即说话者试图向听者提供相关信息;听者可以利用这一假设,根据说话者选择的话语,对世界可能的状态进行*语用*推理。
理性言语行为框架利用概率决策和推理将这些思想形式化。
注意:此 Notebook 必须在 Pyro 4392d54a220c328ee356600fb69f82166330d3d6 或更高版本下运行。
[1]:
#first some imports
import torch
torch.set_default_dtype(torch.float64) # double precision for numerical stability
import collections
import argparse
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine
from search_inference import HashingMarginal, memoize, Search
在定义 RSA 之前,我们首先指定一个封装了推断的辅助函数。Marginal
接收一个非归一化的随机函数,通过使用 Search
构建执行轨迹上的分布,并通过 HashingMarginal
构建返回值上的边际分布。
[2]:
def Marginal(fn):
return memoize(lambda *args: HashingMarginal(Search(fn).run(*args)))
RSA 模型捕获了递归的社会推理——听者思考说话者,说话者思考听者……。
首先,literal_listener
简单地强制规定话语必须为真。数学上表示为
代码如下
[3]:
@Marginal
def literal_listener(utterance):
state = state_prior()
pyro.factor("literal_meaning", 0. if meaning(utterance, state) else -999999.)
return state
接着,合作的说话者选择一个话语,以将给定状态传达给字面听者。数学上表示为
在下面的代码中,utterance_prior
捕获了产生话语的成本,而 pyro.sample
表达式捕获了字面听者猜测到正确状态(obs=state
表示采样值被观察到是正确的 state
)。
我们使用 poutine.scale
将整个执行概率提高到 alpha
次幂——这将得到一个带有最优性参数 alpha
的 softmax 决策规则。
[4]:
@Marginal
def speaker(state):
alpha = 1.
with poutine.scale(scale=torch.tensor(alpha)):
utterance = utterance_prior()
pyro.sample("listener", literal_listener(utterance), obs=state)
return utterance
最后,我们可以定义语用听者,他根据说话者选择的给定话语来推断可能的状态。数学上表示为
代码如下
[5]:
@Marginal
def pragmatic_listener(utterance):
state = state_prior()
pyro.sample("speaker", speaker(state), obs=utterance)
return state
现在,让我们通过填充先验来建立一个简单的世界。我们假设有 4 个物体,每个物体要么是蓝色要么是红色,可能的话语是“没有是蓝色的”、“一些是蓝色的”、“所有都是蓝色的”。
我们假设蓝色物体数量和话语的先验概率是均匀分布的。
[6]:
total_number = 4
def state_prior():
n = pyro.sample("state", dist.Categorical(probs=torch.ones(total_number+1) / total_number+1))
return n
def utterance_prior():
ix = pyro.sample("utt", dist.Categorical(probs=torch.ones(3) / 3))
return ["none","some","all"][ix]
最后,意义函数(上面标记为 \(\mathcal L\))
[7]:
meanings = {
"none": lambda N: N==0,
"some": lambda N: N>0,
"all": lambda N: N==total_number,
}
def meaning(utterance, state):
return meanings[utterance](state)
现在让我们看看它是否有效:语用听者如何解释“一些”这个话语?
[8]:
#silly plotting helper:
def plot_dist(d):
support = d.enumerate_support()
data = [d.log_prob(s).exp().item() for s in d.enumerate_support()]
names = list(map(str, support))
ax = plt.subplot(111)
width = 0.3
bins = [x-width/2 for x in range(1, len(data) + 1)]
ax.bar(bins,data,width=width)
ax.set_xticks(list(range(1, len(data) + 1)))
ax.set_xticklabels(names, rotation=45, rotation_mode="anchor", ha="right")
interp_dist = pragmatic_listener("some")
plot_dist(interp_dist)

太棒了,我们得到了一个*数量蕴涵*:“一些”被解释为可能不包括全部 4 个。试试也看看 literal_listener
——没有蕴涵。