Please check out the `multigpu` branch for the latest (cleaner) version, with newer experiments reported in the EMNLP 2018 paper
==================================
PyTorch implementation of the models described in the paper Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement.
We present code for training and decoding both autoregressive and non-autoregressive models, as well as preprocessed datasets and pretrained models.
Dependencies:
- Python 3.6
- PyTorch 0.3
- Numpy
- NLTK
- torchtext
- torchvision
- CUDA (we recommend the latest version; version 8.0 was used in all our experiments)
- For preprocessing, we used the scripts from Moses and Subword-NMT.
- This code is based on NA-NMT.
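A quick way to check that the environment matches the list above (a generic sanity check, not part of the repository):

```python
# Generic environment check: verifies that the packages listed above import
# and that CUDA is visible. Nothing here is repo-specific.
import torch, torchvision, torchtext, numpy, nltk

print("torch", torch.__version__)             # 0.3.x was used for the experiments
print("CUDA available:", torch.cuda.is_available())
```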
The original translation corpora can be downloaded from (IWSLT'16 En-De, WMT'16 En-Ro, WMT'15 En-De, MS COCO). For the preprocessed corpora and pre-trained models, see below.
Dataset | Data | Models
---|---|---
IWSLT'16 En-De | Data | Models
WMT'16 En-Ro | Data | Models
WMT'15 En-De | Data | Models
MS COCO | Data | Models
- Set the correct path to the data in the `data_path()` function located in `data.py` (see the sketch after this list).
- For `vocab_size`, use `60000` for WMT'15 En-De, `40000` for the other translation datasets, and `10000` for MS COCO.
- For `params`, use `big` for WMT'15 En-De and `small` for the other translation datasets.
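The exact contents of `data_path()` depend on the repository version, but the edit amounts to mapping each dataset to the directory where its preprocessed files live. A minimal sketch (the function signature, dataset keys, and paths below are illustrative assumptions, not the repo's actual code):

```python
# Illustrative sketch of data.py's data_path(); the real function may differ.
def data_path(dataset):
    # Map the --dataset identifier to the directory holding its preprocessed files.
    roots = {
        "iwslt-ende": "/path/to/iwslt16.ende",  # hypothetical keys and paths
        "wmt16-enro": "/path/to/wmt16.enro",
        "wmt15-ende": "/path/to/wmt15.ende",
        "mscoco": "/path/to/mscoco",
    }
    return roots[dataset]
```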
Decode with a pre-trained autoregressive model:
$ python run.py --dataset <dataset> --vocab_size <vocab_size> --ffw_block highway --params <params> --lr_schedule anneal --mode test --debug --load_from <checkpoint>
Decode with a pre-trained non-autoregressive model (iterative refinement):
$ python run.py --dataset <dataset> --vocab_size <vocab_size> --ffw_block highway --params <params> --lr_schedule anneal --fast --valid_repeat_dec 20 --use_argmax --next_dec_input both --mode test --remove_repeats --debug --trg_len_option predict --use_predicted_trg_len --load_from <checkpoint>
For adaptive decoding, add the flag `--adaptive_decoding jaccard` to the command above.
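Put together, the adaptive-decoding call is simply the non-autoregressive decoding command above with that flag appended:
$ python run.py --dataset <dataset> --vocab_size <vocab_size> --ffw_block highway --params <params> --lr_schedule anneal --fast --valid_repeat_dec 20 --use_argmax --next_dec_input both --mode test --remove_repeats --debug --trg_len_option predict --use_predicted_trg_len --load_from <checkpoint> --adaptive_decoding jaccard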
Train a new autoregressive model:
$ python run.py --dataset <dataset> --vocab_size <vocab_size> --ffw_block highway --params <params> --lr_schedule anneal
Train a new non-autoregressive model:
$ python run.py --dataset <dataset> --vocab_size <vocab_size> --ffw_block highway --params <params> --lr_schedule anneal --fast --valid_repeat_dec 8 --use_argmax --next_dec_input both --denoising_prob --layerwise_denoising_weight --use_distillation
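For example, following the `vocab_size`/`params` guidance above, a WMT'15 En-De non-autoregressive training run would use `--vocab_size 60000 --params big` (the `<dataset>` identifier is left as a placeholder, since its exact value depends on `data.py`):
$ python run.py --dataset <dataset> --vocab_size 60000 --ffw_block highway --params big --lr_schedule anneal --fast --valid_repeat_dec 8 --use_argmax --next_dec_input both --denoising_prob --layerwise_denoising_weight --use_distillation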
To fine-tune target length prediction on a trained non-autoregressive model:
- Take a checkpoint of a pre-trained non-autoregressive model.
- Resume training using the flags below, in addition to the same flags used in step 1 (a composed example follows):
--load_from <checkpoint> --resume --finetune_trg_len --trg_len_option predict
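Assuming "step 1" refers to the non-autoregressive training command above, the resumed run would look roughly like this (a sketch, not a verbatim command from the authors):
$ python run.py --dataset <dataset> --vocab_size <vocab_size> --ffw_block highway --params <params> --lr_schedule anneal --fast --valid_repeat_dec 8 --use_argmax --next_dec_input both --denoising_prob --layerwise_denoising_weight --use_distillation --load_from <checkpoint> --resume --finetune_trg_len --trg_len_option predict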
For MS COCO:
- Run the pre-trained autoregressive model
python run.py --dataset mscoco --params big --load_vocab --mode test --n_layers 4 --ffw_block highway --debug --load_from mscoco_models_final/ar_model --batch_size 1024
- Run the pre-trained non-autoregressive model
python run.py --dataset mscoco --params big --use_argmax --load_vocab --mode test --n_layers 4 --fast --ffw_block highway --debug --trg_len_option predict --use_predicted_trg_len --load_from mscoco_models_final/nar_model --batch_size 1024
- Train a new autoregressive model
python run.py --dataset mscoco --params big --batch_size 1024 --load_vocab --eval_every 1000 --drop_ratio 0.5 --lr_schedule transformer --n_layers 4
- Train a new non-autoregressive model
python run.py --dataset mscoco --params big --use_argmax --batch_size 1024 --load_vocab --eval_every 1000 --drop_ratio 0.5 --lr_schedule transformer --n_layers 4 --fast --use_distillation --ffw_block highway --denoising_prob 0.5 --layerwise_denoising_weight --load_encoder_from mscoco_models_final/ar_model
After training it, train the length predictor (set the correct path in the `load_from` argument):
python run.py --dataset mscoco --params big --use_argmax --batch_size 1024 --load_vocab --mode train --n_layers 4 --fast --ffw_block highway --eval_every 1000 --drop_ratio 0.5 --drop_len_pred 0.0 --lr_schedule anneal --anneal_steps 100000 --use_distillation --load_from mscoco_models/new_nar_model --trg_len_option predict --finetune_trg_len --max_offset 20
If you find the resources in this repository useful, please consider citing:
@article{Lee:18,
author = {Jason Lee and Elman Mansimov and Kyunghyun Cho},
title = {Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement},
year = {2018},
journal = {arXiv preprint arXiv:1802.06901},
}