bandit-nmt

THIS REPO DEMONSTRATES HOW TO INTEGRATE A POLICY GRADIENT METHOD INTO NMT. FOR A STATE-OF-THE-ART NMT CODEBASE, VISIT simple-nmt.

This is the code repo for our EMNLP 2017 paper "Reinforcement Learning for Bandit Neural Machine Translation with Simulated Human Feedback", which implements the A2C algorithm on top of a neural encoder-decoder model and benchmarks the combination under simulated noisy rewards.
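The core of the A2C combination is a per-sentence advantage: the critic predicts the reward a sampled translation will receive, and the actor's log-probability is weighted by how much the observed (simulated, possibly noisy) reward exceeds that prediction. The following is a minimal plain-Python sketch of those two losses, not the repo's implementation (the function name `a2c_losses` is hypothetical):

```python
def a2c_losses(log_probs, rewards, values):
    """Per-sentence A2C losses (illustrative sketch).

    log_probs: summed token log-probabilities of each sampled translation
    rewards:   simulated per-sentence feedback (e.g. noisy sentence-BLEU)
    values:    critic's predicted reward for each sentence
    """
    n = len(rewards)
    advantages = [r - v for r, v in zip(rewards, values)]
    # policy-gradient loss: advantages act as fixed weights on the log-probs
    actor_loss = -sum(lp * a for lp, a in zip(log_probs, advantages)) / n
    # critic loss: regress the value estimates toward the observed rewards
    critic_loss = sum((v - r) ** 2 for v, r in zip(values, rewards)) / n
    return actor_loss, critic_loss
```

A translation that earns more reward than the critic expected gets its log-probability pushed up; one that earns less gets pushed down.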

Requirements:

- Python 3.6
- PyTorch 0.2

NOTE: as of Sep 16 2017, the code got 2x slower when I upgraded to PyTorch 0.2. This is a known issue and the PyTorch team is fixing it.

IMPORTANT: Set the home directory (otherwise scripts will not run correctly):

~~~~
export BANDIT_HOME=$PWD
export DATA=$BANDIT_HOME/data
export SCRIPT=$BANDIT_HOME/scripts
~~~~

Data extraction

Download pre-processing scripts:

~~~~
cd $DATA/scripts
bash download_scripts.sh
~~~~

For German-English:

~~~~
cd $DATA/en-de
bash extract_data_de_en.sh
~~~~

NOTE: train_2014 and train_2015 overlap heavily. Be cautious when using them for other projects.

Data should be ready in $DATA/en-de/prep

TODO: Chinese-English needs segmentation

Data pre-processing

~~~~
cd $SCRIPT
bash make_data.sh de en
~~~~

Pretraining

Pretrain both actor and critic:

~~~~
cd $SCRIPT
bash pretrain.sh en-de $YOUR_LOG_DIR
~~~~

See scripts/pretrain.sh for more details.

Pretrain actor only:

~~~~
cd $BANDIT_HOME
python train.py -data $YOUR_DATA -save_dir $YOUR_SAVE_DIR -end_epoch 10
~~~~

Reinforcement training

~~~~
cd $BANDIT_HOME
~~~~

From scratch:

~~~~
python train.py -data $YOUR_DATA -save_dir $YOUR_SAVE_DIR -start_reinforce 10 -end_epoch 100 -critic_pretrain_epochs 5
~~~~

From a pretrained model:

~~~~
python train.py -data $YOUR_DATA -load_from $YOUR_MODEL -save_dir $YOUR_SAVE_DIR -start_reinforce -1 -end_epoch 100 -critic_pretrain_epochs 5
~~~~

Perturbed rewards

For example, use a thumbs-up/thumbs-down reward:

~~~~
cd $BANDIT_HOME
python train.py -data $YOUR_DATA -load_from $YOUR_MODEL -save_dir $YOUR_SAVE_DIR -start_reinforce -1 -end_epoch 100 -critic_pretrain_epochs 5 -pert_func bin -pert_param 1
~~~~

See lib/metric/PertFunction.py for more types of perturbation function.
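To illustrate what a binning perturbation does, the sketch below snaps a continuous reward in [0, 1] to a small number of evenly spaced levels; with a granularity of 1 this collapses to binary thumbs-up/thumbs-down feedback. This is a hypothetical example (`bin_reward` is not the repo's function name), not the code in PertFunction.py:

```python
def bin_reward(reward, granularity):
    """Snap a reward in [0, 1] to the nearest of `granularity` + 1
    evenly spaced levels. granularity=1 gives binary feedback
    (thumbs-up = 1.0, thumbs-down = 0.0)."""
    return round(reward * granularity) / granularity
```

For example, with granularity 1 a sentence-BLEU of 0.72 becomes a thumbs-up (1.0) and 0.3 becomes a thumbs-down (0.0); with granularity 4 the reward 0.72 is snapped to 0.75.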

Evaluation

~~~~
cd $BANDIT_HOME
~~~~

On heldout sets (heldout BLEU):

~~~~
python train.py -data $YOUR_DATA -load_from $YOUR_MODEL -eval -save_dir .
~~~~

On bandit set (per-sentence BLEU):

~~~~
python train.py -data $YOUR_DATA -load_from $YOUR_MODEL -eval_sample -save_dir .
~~~~
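Per-sentence BLEU differs from corpus BLEU in that a single short hypothesis can have zero higher-order n-gram matches, so some smoothing is needed to keep the score non-zero. A minimal add-1-smoothed sketch (not the repo's metric code; `sentence_bleu` here is an illustrative stand-in):

```python
from collections import Counter
import math

def sentence_bleu(hyp, ref, max_n=4):
    """Per-sentence BLEU with add-1 smoothing on each n-gram precision."""
    hyp, ref = hyp.split(), ref.split()
    log_prec = 0.0
    for n in range(1, max_n + 1):
        h = Counter(tuple(hyp[i:i + n]) for i in range(len(hyp) - n + 1))
        r = Counter(tuple(ref[i:i + n]) for i in range(len(ref) - n + 1))
        matched = sum(min(c, r[g]) for g, c in h.items())  # clipped matches
        total = max(sum(h.values()), 1)
        log_prec += math.log((matched + 1) / (total + 1)) / max_n
    # brevity penalty discourages overly short hypotheses
    bp = min(1.0, math.exp(1 - len(ref) / max(len(hyp), 1)))
    return bp * math.exp(log_prec)
```

Note that under add-1 smoothing even an exact match of a short sentence scores below 1.0, because its missing higher-order n-grams are smoothed rather than ignored.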

