用 BERT 和表示学习改进句向量
原标题: 用 BERT 和表示学习改进句向量
来源:AI研习社 链接:https://www.yanxishe.com/TextTranslation/2473
在这个实验中,我们微调了一个BERT模型,以提高其编码短文本的能力。这将为 downstream NLP task 生成更有用的语句嵌入。
虽然vanilla-BERT可以用来编码句子,但是用它生成的嵌入并不健壮。正如我们在下面看到的,模型认为相似的样本在词汇上往往比语义上更相关。输入样本中的小扰动会导致预测相似度的大变化
平均集合BERT基模型编码句子对的相似性
为了改进,我们使用了斯坦福自然语言推理数据集,该数据集包含人工标注的包含蕴涵、矛盾和中性标记的句子对。对于这些句子,我们将学习这样一种表示法,即蕴涵对之间的相似性大于矛盾对之间的相似性。
为了评估学习嵌入的质量,我们在STS和SICK-R数据集上测量了Spearman秩相关。
这个实验的计划是:
重新配置SNLI和MNLI数据集
实现数据生成器
定义损失
建立模型
准备评估pipeline
训练模型
本指南包含在标记数据上构建和训练句子编码器的代码。
对于一个熟悉的读者来说,完成本指南和训练句子编码器大约需要90分钟。用tensorflow==1.15测试代码。
这里有这个实验的代码。这一次,以前实验中的大部分代码都被重用了。我建议你先去看看。
独立版本可以在存储库中找到。
我们首先下载SNLI、MNLI、STS和SICK数据集以及预先训练的英语BERT模型。
为了更方便地处理数据集,我们对其进行了一些重新安排。对于每个唯一的锚,我们创建一个ID和一个包含anchor、蕴涵和矛盾样本的条目。每个类中缺少至少一个样本的锚被过滤掉。
一个条目如下
{ 'anchor': ["No, don't answer."], 'contradiction': ['Please respond.'], 'entailment': ["Don't respond. ", "Don't say a word. "] }
最后,加载SNLI和MNLI数据集。
为了训练这个模型,我们将抽取三个样本,包括anchor、正样本和负样本。要处理复杂的批处理生成逻辑,我们使用以下代码:
高级逻辑包含在generate_batch方法中。
批处理anchor ID是从所有可用ID中随机选择的。
anchor 样本是从其id的anchor 样本中检索的。
正样本是从他们的ID的衍生样本中提取的。
负样本从其id的矛盾样本中提取。
这些可能被认为是难例( hard negative samples ),因为它们通常在语义上与锚相似。为了减少过度拟合,我们将它们与从其他随机ID检索的随机负样本混合。
我们可以把学习句子相似度的问题作为一个排序问题。假设我们有一个由k个释义句子对x和y组成的语料库,并且想学习一个估计y是否是x的释义的函数。对于某些x,我们有一个正样本y和一个负样本k-1
P(x,y)的联合概率用分类评分函数S估计:
在训练过程中,不可能对数据集中的所有k-1负样本求和。取而代之的是,我们通过从每个批次的语料库中抽取K个响应作为负样本来近似P(x)。我们得到:
我们将最小化数据的负对数概率。因此,对于一批K三元组的损失,我们可以写下:
注:上述表达式称为Softmax Loss。
本实验采用内积作为相似函数。计算最后一个括号中表达式的代码如下
首先,我们从之前的实验中导入微调代码并构建BERT模块。
该模型对锚定样本、正样本和负样本有三个输入。使用具有平均池操作的BERT层作为共享文本编码器。
文本预处理由编码器层处理。对编码的句子计算Softmax损失。
为了方便起见,本文建立了三种模型:编码句子的enc_model、计算句子对相似度的sim_model和训练句子的trn_model。所有模型都使用共享权重。
自然语言编码器通常通过嵌入标记的句子对,测量它们之间的某种相似性,然后计算这种相似性与人类判断的相关性来进行评估。
我们使用STS 2012-2016和SICK 2014数据集来评估我们的模型。对于测试集中的所有句子对,我们计算余弦相似度。我们报告了Pearson秩与人类注释标签的相关性。
下面的回调处理求值过程,并在每次获得新的最佳结果时将提供的存储模型保存到savepath。
我们对模型进行了10个阶段的训练,每个阶段有256个批次,每批由256个三元组组成,我们在每个时段的开始都会进行评估。
trn_model.fit_generator(tr_gen._generator, validation_data=ts_gen._generator, steps_per_epoch=256, validation_steps=32, epochs=5, callbacks=callbacks)*** New best: STS_spearman_r = 0.5426*** New best: STS_pearson_r = 0.5481*** New best: SICK_spearman_r = 0.5799*** New best: SICK_pearson_r = 0.6069Epoch 1/10255/256 [============================>.] - ETA: 1s - loss: 0.6858256/256 [==============================] - 535s 2s/step - loss: 0.6844 - val_loss: 0.4366*** New best: STS_spearman_r = 0.7186*** New best: STS_pearson_r = 0.7367*** New best: SICK_spearman_r = 0.7258*** New best: SICK_pearson_r = 0.8098Epoch 2/10255/256 [============================>.] - ETA: 1s - loss: 0.3950256/256 [==============================] - 524s 2s/step - loss: 0.3950 - val_loss: 0.3700*** New best: STS_spearman_r = 0.7337*** New best: STS_pearson_r = 0.7495*** New best: SICK_spearman_r = 0.7444*** New best: SICK_pearson_r = 0.8216... Epoch 9/10255/256 [============================>.] - ETA: 1s - loss: 0.2481256/256 [==============================] - 524s 2s/step - loss: 0.2481 - val_loss: 0.2631*** New best: STS_spearman_r = 0.7536*** New best: STS_pearson_r = 0.7638*** New best: SICK_spearman_r = 0.7623*** New best: SICK_pearson_r = 0.8316Epoch 10/10255/256 [============================>.] - ETA: 1s - loss: 0.2381256/256 [==============================] - 525s 2s/step - loss: 0.2383 - val_loss: 0.2492*** New best: STS_spearman_r = 0.7547*** New best: STS_pearson_r = 0.7648*** New best: SICK_spearman_r = 0.7628*** New best: SICK_pearson_r = 0.8325
为了参考,我们可以检查句子BERT论文中的评估结果,作者在STS和SICK任务中评估了几个预先训练的句子嵌入系统。
我们的结果与发表的标准一致,在SICK-R上得到57.99斯皮尔曼秩相关得分。
经过10个阶段后,最佳Colab模型达到76.94,与通用句子编码的最佳结果76.69不相上下。
不同文本相似性任务中句子表示的余弦相似性与金标之间的Spearman秩相关。(摘自句子伯特:使用连体伯特网络的句子嵌入)
由于在批处理中的所有样本之间共享负示例,因此使用较大的批处理大小进行微调往往会将度量提高到某一点。解冻更多的编码器层也有帮助,如果你的GPU可以处理它。
一旦训练完成,我们就可以通过编码测试三元组并同时检查预测的相似性来比较基础模型和训练模型。一些例子:
以上我们提出了一种利用标记句子对改进句子嵌入的方法。
通过显式地训练模型,根据语义关系对句子对进行编码,我们能够学习更有效和更健壮的句子表示。
无论是自动评估还是手动评估,都比基线句子表示模型有了实质性的改进。
发起:唐里 校对:鸢尾 审核:唐里
参与翻译(1人):邺调
英文原文:Improving sentence embeddings with BERT and Representation Learning
一THE END一
免责声明:本文来自互联网新闻客户端自媒体,不代表本网的观点和立场。
合作及投稿邮箱:E-mail:editor@tusaishared.com
热门资源
Python 爬虫(二)...
所谓爬虫就是模拟客户端发送网络请求,获取网络响...
TensorFlow从1到2...
原文第四篇中,我们介绍了官方的入门案例MNIST,功...
TensorFlow从1到2...
“回归”这个词,既是Regression算法的名称,也代表...
机器学习中的熵、...
熵 (entropy) 这一词最初来源于热力学。1948年,克...
TensorFlow2.0(10...
前面的博客中我们说过,在加载数据和预处理数据时...
智能在线
400-630-6780
聆听.建议反馈
E-mail: support@tusaishared.com