# TrellisNet
This repository contains the experiments from the paper *Trellis Networks for Sequence Modeling* by Shaojie Bai, J. Zico Kolter, and Vladlen Koltun.
On the one hand, a trellis network is a temporal convolutional network with a special structure, characterized by weight tying across depth and direct injection of the input into deep layers. On the other hand, we show that truncated recurrent networks are equivalent to trellis networks with a special sparsity structure in their weight matrices. Thus trellis networks with general weight matrices generalize truncated recurrent networks. This allows trellis networks to serve as a bridge between recurrent and convolutional architectures, benefiting from algorithmic and architectural techniques developed in either context. We leverage these relationships to design high-performing trellis networks that absorb ideas from both architectural families. Experiments demonstrate that trellis networks outperform the current state of the art on a variety of challenging benchmarks, including word-level language modeling on Penn Treebank and WikiText-103, character-level language modeling on Penn Treebank, and stress tests designed to evaluate long-term memory retention.
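The two defining ingredients above (one shared set of weights reused at every depth, plus direct injection of the raw input into every layer) can be sketched in a few lines of NumPy. This is a minimal illustration only, not the repository's implementation; the dimensions, weight names, and kernel-size-2 causal structure below are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d_in, d_hid, depth = 8, 4, 6, 5

# One shared set of weights, reused at every depth (weight tying across layers):
W_prev = rng.standard_normal((d_hid, d_hid)) * 0.1  # mixes the previous time step
W_curr = rng.standard_normal((d_hid, d_hid)) * 0.1  # mixes the current time step
W_inj = rng.standard_normal((d_in, d_hid)) * 0.1    # injects the raw input

x = rng.standard_normal((seq_len, d_in))  # input sequence, one row per time step
z = np.zeros((seq_len, d_hid))            # hidden activations at depth 0

for _ in range(depth):  # identical weights at every layer
    # Causal shift: each step sees the previous step's activation (zeros at t=0).
    z_prev = np.vstack([np.zeros((1, d_hid)), z[:-1]])
    # Kernel-size-2 causal convolution plus direct input injection.
    z = np.tanh(z_prev @ W_prev + z @ W_curr + x @ W_inj)

print(z.shape)  # (8, 6)
```

Because the same `W_prev`, `W_curr`, and `W_inj` are applied at every depth, adding layers extends the temporal receptive field without adding parameters; a truncated RNN corresponds to the special case where these matrices carry a particular sparsity pattern.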
Our experiments were done in PyTorch. If you find this work or repository helpful, please consider citing it:
```
@article{BaiTrellis2018,
  author  = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
  title   = {Trellis Networks for Sequence Modeling},
  journal = {arXiv:1810.06682},
  year    = {2018},
}
```
The code should be directly runnable with PyTorch 0.4.0 (although slight modifications may be needed for other versions). This repository contains training scripts for the following tasks:
- Sequential MNIST handwritten digit classification
- Permuted Sequential MNIST, which randomly permutes the pixel order in sequential MNIST
- Sequential CIFAR-10 classification (more challenging, due to greater intra-class variation, channel complexity, and larger images)
- Penn Treebank (PTB) word-level language modeling (with and without the mixture of softmaxes); vocabulary size 10K
- WikiText-103 (WT103) large-scale word-level language modeling; vocabulary size 268K
- Penn Treebank medium-scale character-level language modeling
Note that these tasks are on very different scales, with unique properties that challenge sequence models in different ways. For example, word-level PTB is a small dataset that a typical model easily overfits, so judicious regularization is essential. WT103 is a hundred times larger, with less danger of overfitting, but with a vocabulary size of 268K that makes training more challenging (due to large embedding size).
We provide some reasonably good pre-trained weights here so that users don't need to train from scratch. We'll update the table from time to time. Note: 1) it can take a while to load the weights; 2) if you train from scratch using different seeds (which we did not explore), you may well get better results :-)
| Description   | Task              | Dataset             | Model           |
| ------------- | ----------------- | ------------------- | --------------- |
| TrellisNet-LM | Language Modeling | Penn Treebank (PTB) | download (.pkl) |
To use the pre-trained weights, pass the flag `--load_weight [.pkl PATH]` when starting the training script.
All tasks share the same underlying TrellisNet model, which is defined in `trellisnet.py` (the full models, including components such as the embedding layer, are in `model.py`). As discussed in the paper, TrellisNet benefits significantly from techniques originally developed for RNNs as well as for temporal convolutional networks (TCNs). Some of these techniques are included in this repository. Each task is organized in the following structure:
```
[TASK_NAME]/
    data/
    logs/
    [TASK_NAME].py
    model.py
    utils.py
    data.py
```
where `[TASK_NAME].py` is the training script for the task (with argument flags; use `-h` to see the details).
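A typical invocation might look like the following. The task directory and checkpoint path here are placeholders — substitute the actual `[TASK_NAME]` directory and the `.pkl` file you downloaded:

```shell
# List all available flags for a task's training script
python [TASK_NAME]/[TASK_NAME].py -h

# Start training from pre-trained weights via the --load_weight flag
python [TASK_NAME]/[TASK_NAME].py --load_weight pretrained_weights.pkl
```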