资源算法Bert-TextClassification

Bert-TextClassification

2020-03-10 | |  38 |   0 |   0

Introduction

本仓库专注于 Bert 在文本分类领域的应用, 探索在 Bert 之上如何提高文本分类上的表现。

Requirements

下面命令还未经过完整测试, 可以参考。

推荐使用 Anconda 来管理包环境, 我采用的是 Anconda python 3.7,其余 3.0 以上应该都可以, 推荐新建一个环境来做测试。

conda create -n BertText  # 创建新环境
conda activate BertText   # 激活指定环境

Pytorch : [conda install pytorch torchvision cudatoolkit=9.0 -c pytorch](https://pytorch.org/get-started/locally/)

scikit-learn: conda install scikit-learn

pytorch-pretrained-BERT: pip install pytorch-pretrained-bert

numpy: conda install numpy

tensorboardx: pip install tensorboardX

tensorflow: pip install tensorflow

数据集

打算使用多个文本分类数据集来进行试验,以获得更佳的调参体验,主要包含情感分类, 问题分类以及主题分类三种:

  • 情感分类: 采用 IMDB, SST-2, 以及 Yelp 数据集。

  • 问题分类: 采用 TREC 和 Yahoo! Answers 数据集。

  • 主题分类: 采用 AG's News,DBPedia 以及 CNews。

仓库采用了三个数据集,分别是 SST-2 情感分类, Yelp多标签分类, THUCNews 多标签分类。

其中, THUCNews 只选取了一个子集, 该子集中包括了10个分类,每个分类6500条数据。

sst-2: 链接:https://pan.baidu.com/s/1ax9uCjdpOHDxhUhpdB0d_g  提取码:rxbi 
cnews: 链接:https://pan.baidu.com/s/19sOrAxSKn3jCIvbVoD_-ag  提取码:rstb

关于 Bert

这里,使用了 pytorch-pretrained-BERT 来加载 Bert 模型, 考虑到国内网速问题,推荐先将相关的 Bert 文件下载,主要有两种文件:

  • vocab.txt: 记录了Bert中所用词表

  • 模型参数: 主要包括预训练模型的相关参数

相关文件下载连接在 Bert

实验设置

  • 没有删除在单机多卡上的逻辑,只是删除了分布式运算的逻辑,主要是考虑到大多数实验大家都没有必要去用到分布式。

  • 删除了采用 fp16 的逻辑, 考虑到文本分类所需的资源并没有那么大, 采用 默认的32位浮点类型在大多数情况下是可以的, 没必要损失精度。其实最主要的还是精简逻辑。

  • 注意: Bert 的参数量随着文本长度的增加呈现接近线性变化的趋势, 而 THUCNews 数据集的文本长度大多在1000-4000之间,这对于大多数机器是不可承受的, 测试在单1080ti上, 文本长度设置为150左右已经是极限。

  • 注意: 我有用 tensorboard 将相关的日志信息保存,推荐采用 tensorboard 进行分析。

Results

THUCNews

注意: THUCNews 数据集中的样本长度十分的长,上面说到 Bert 本身对于序列长度十分敏感,因此我在我单1080ti下所能支持的最大长度。这也导致运行时间的线性增加,1个epoch 大概需要1个半小时到2个小时之间

python3 run_CNews.py --max_seq_length=512 --num_train_epochs=5.0 --do_train --gpu_ids="4 5 6 7" --gradient_accumulation_steps=8 --print_step=500  # gpu_ids 选择 gpu, 如果是单gpu, 选择 max_seq_length 为150较为合适(1080ti)
python3 run_CNews.py --max_seq_length=512
model_namelossaccf1
BertOrigin(base)0.08897.4097.39












BertHAN0.10397.4997.48




SST-2

python3 run_SST2.py --max_seq_length=65 --num_train_epochs=5.0 --do_train --gpu_ids="1" --gradient_accumulation_steps=8 --print_step=100  # train and test
python3 run_SST2.py --max_seq_length=65   # test
模型lossaccf1
BertOrigin(base)0.17094.45894.458
BertCNN (5,6) (base)0.14894.60794.62
BertATT (base)0.16294.21194.22
BertRCNN (base)0.14595.15195.15
BertCNNPlus (base)0.16094.50894.51

如何适配自己的数据集

对于新的数据集,只需要将你的数据集转化为对应的 tsv 格式:

sentence label

然后简历一个 run_your_dataset.py, 然后模仿 run_SST2.py 修改对应的文件夹和label_list, 其余的文件完全不需要改动, 不需要设置 Processor, 因为我将这部分重新封装了一下。

关于保存对应的结果

有同学提出要求能够最终获得 id, pred_label, true_label 三元组, 考虑到 Pytorch 中无法使用字符串,因此采用数字0,1,...,n 表示,因此如果是想要对应真实的 id 的话,需要我们将数字与id进行对应,其实很简单, Excel 排个序然后复制粘贴就行。


上一篇:SpanBERT

下一篇:AzureML-BERT

用户评价
全部评价

热门资源

  • seetafaceJNI

    项目介绍 基于中科院seetaface2进行封装的JAVA...

  • spark-corenlp

    This package wraps Stanford CoreNLP annotators ...

  • Keras-ResNeXt

    Keras ResNeXt Implementation of ResNeXt models...

  • capsnet-with-caps...

    CapsNet with capsule-wise convolution Project ...

  • inferno-boilerplate

    This is a very basic boilerplate example for pe...