bandit-nmt
THIS REPO DEMONSTRATES HOW TO INTEGRATE A POLICY GRADIENT METHOD INTO NMT. FOR A STATE-OF-THE-ART NMT CODEBASE, VISIT simple-nmt.
This is code repo for our EMNLP 2017 paper "Reinforcement Learning for Bandit Neural Machine Translation with Simulated Human Feedback", which implements the A2C algorithm on top of a neural encoder-decoder model and benchmarks the combination under simulated noisy rewards.
Requirements: - Python 3.6 - PyTorch 0.2
NOTE: as of Sep 16 2017, the code got 2x slower when I upgraded to PyTorch 2.0. This is a known issue and PyTorch is fixing it.
IMPORTANT: Set home directory (otherwise scripts will not run correctly): ~~~~
export BANDIT_HOME=$PWD export DATA=$BANDIT_HOME/data export SCRIPT=$BANDIT_HOME/scripts ~~~~
Data extraction
Download pre-processing scripts ~~~~
cd $DATA/scripts bash download_scripts.sh ~~~~
For German-English ~~~~
cd $DATA/en-de bash extract_data_de_en.sh ~~~~
NOTE: train_2014 and train_2015 highly overlap. Please be cautious when using them for other projects.
Data should be ready in $DATA/en-de/prep
TODO: Chinese-English needs segmentation
Data pre-processing
~~~~
cd $SCRIPT bash make_data.sh de en ~~~~
Pretraining
Pretrain both actor and critic ~~~~
cd $SCRIPT bash pretrain.sh en-de $YOUR_LOG_DIR ~~~~
See scripts/pretrain.sh
for more details.
Pretrain actor only ~~~~
cd $BANDIT_HOME python train.py -data $YOUR_DATA -save_dir $YOUR_SAVE_DIR -end_epoch 10 ~~~~
Reinforcement training
~~~~
cd $BANDIT_HOME ~~~~
From scratch ~~~~
python train.py -data $YOUR_DATA -save_dir $YOUR_SAVE_DIR -start_reinforce 10 -end_epoch 100 -critic_pretrain_epochs 5 ~~~~
From a pretrained model ~~~~
python train.py -data $YOUR_DATA -load_from $YOUR_MODEL -save_dir $YOUR_SAVE_DIR -start_reinforce -1 -end_epoch 100 -critic_pretrain_epochs 5 ~~~~
Perturbed rewards
For example, use thumb up/thump down reward: ~~~~
cd $BANDIT_HOME python train.py -data $YOUR_DATA -load_from $YOUR_MODEL -save_dir $YOUR_SAVE_DIR -start_reinforce -1 -end_epoch 100 -critic_pretrain_epochs 5 -pert_func bin -pert_param 1 ~~~~
See lib/metric/PertFunction.py
for more types of function.
Evaluation
~~~~
cd $BANDIT_HOME ~~~~
On heldout sets (heldout BLEU): ~~~~
python train.py -data $YOUR_DATA -load_from $YOUR_MODEL -eval -save_dir . ~~~~
On bandit set (per-sentence BLEU): ~~~~
python train.py -data $YOUR_DATA -load_from $YOUR_MODEL -eval_sample -save_dir . ~~~~